# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import enum
import inspect
from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union, overload

import numpy as np

from . import ops
from . import jax_jit
from . import outfeed_receiver
from . import pmap_lib
from . import profiler
from . import pytree

_LiteralSlice = Any
_Status = Any
_Dtype = Any
_XlaOpMetadata = Any

_T = TypeVar("_T")
class PrimitiveType(enum.IntEnum):
  PRIMITIVE_TYPE_INVALID: PrimitiveType
  PRED: PrimitiveType
  S8: PrimitiveType
  S16: PrimitiveType
  S32: PrimitiveType
  S64: PrimitiveType
  U8: PrimitiveType
  U16: PrimitiveType
  U32: PrimitiveType
  U64: PrimitiveType
  BF16: PrimitiveType
  F16: PrimitiveType
  F32: PrimitiveType
  F64: PrimitiveType
  C64: PrimitiveType
  C128: PrimitiveType
  TUPLE: PrimitiveType
  OPAQUE_TYPE: PrimitiveType
  TOKEN: PrimitiveType

def bfloat16_dtype() -> Type[Any]: ...

# === BEGIN xla_compiler.cc

class Shape:
  def __init__(self, s: str) -> None: ...
  @staticmethod
  def tuple_shape(shapes: Sequence[Shape]) -> Shape: ...
  @staticmethod
  def array_shape(
      type: Union[np.dtype, PrimitiveType],
      dims_seq: Any = ...,
      layout_seq: Any = ...,
      dynamic_dimensions: Optional[List[bool]] = ...) -> Shape: ...
  @staticmethod
  def token_shape() -> Shape: ...
  @staticmethod
  def scalar_shape(type: Union[np.dtype, PrimitiveType]) -> Shape: ...
  def dimensions(self) -> Tuple[int, ...]: ...
  def xla_element_type(self) -> PrimitiveType: ...
  def element_type(self) -> np.dtype: ...
  def numpy_dtype(self) -> np.dtype: ...
  def is_tuple(self) -> bool: ...
  def is_array(self) -> bool: ...
  def is_token(self) -> bool: ...
  def is_static(self) -> bool: ...
  def is_dynamic(self) -> bool: ...
  def is_dynamic_dimension(self, dimension: int) -> bool: ...
  def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ...
  def rank(self) -> int: ...
  def to_serialized_proto(self) -> bytes: ...
  def tuple_shapes(self) -> List[Shape]: ...
  def leaf_count(self) -> int: ...
  def with_major_to_minor_layout_if_absent(self) -> Shape: ...
  def __eq__(self, other: Shape) -> bool: ...
  def __ne__(self, other: Shape) -> bool: ...
  def __hash__(self) -> int: ...
  def __repr__(self) -> str: ...

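# Usage sketch (not part of the stub): building Shape values by hand, using
# only the constructors declared above. The dtype and dimensions are
# illustrative.
#
#   f32 = np.dtype("float32")
#   scalar = Shape.scalar_shape(f32)
#   matrix = Shape.array_shape(f32, (2, 3))
#   pair = Shape.tuple_shape([scalar, matrix])
#   assert pair.is_tuple() and matrix.rank() == 2
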
class ProgramShape:
  def __init__(self, params: Sequence[Shape], result: Shape) -> None: ...
  def parameter_shapes(self) -> List[Shape]: ...
  def result_shape(self) -> Shape: ...
  def __repr__(self) -> str: ...

class ShapeIndex:
  def __init__(self, indices: List[int]) -> None: ...
  def __eq__(self, other: ShapeIndex) -> bool: ...
  def __ne__(self, other: ShapeIndex) -> bool: ...
  def __hash__(self) -> int: ...
  def __repr__(self) -> str: ...

class Literal:
  def __repr__(self) -> str: ...

class XlaComputation:
  def __init__(self, serialized_hlo_module_proto: bytes) -> None: ...
  def get_hlo_module(self) -> HloModule: ...
  def program_shape(self) -> ProgramShape: ...
  def as_serialized_hlo_module_proto(self) -> bytes: ...
  def as_hlo_text(self) -> str: ...
  def as_hlo_dot_graph(self) -> str: ...
  def hash(self) -> int: ...
  def as_hlo_module(self) -> HloModule: ...

class HloPrintOptions:
  def __init__(self) -> None: ...
  @staticmethod
  def short_parsable() -> HloPrintOptions: ...
  @staticmethod
  def canonical() -> HloPrintOptions: ...
  @staticmethod
  def fingerprint() -> HloPrintOptions: ...
  print_large_constants: bool
  print_metadata: bool
  print_backend_config: bool
  print_result_shape: bool
  print_operand_shape: bool
  print_operand_names: bool
  print_ids: bool
  print_extra_attributes: bool
  print_program_shape: bool
  print_percent: bool
  print_control_dependencies: bool
  compact_operands: bool
  include_layout_in_shapes: bool
  canonicalize_instruction_names: bool
  canonicalize_computations: bool
  indent_amount: int
  is_in_nested_computation: bool
  leading_and_trailing_instructions_number: int

class HloModule:
  def to_string(self, options: HloPrintOptions = ...) -> str: ...

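# Usage sketch (not part of the stub): rendering an HloModule as text,
# assuming `module` came from, e.g., XlaComputation.get_hlo_module().
#
#   opts = HloPrintOptions.short_parsable()
#   opts.print_metadata = False
#   print(module.to_string(opts))
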
def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ...

def hlo_module_cost_analysis(
    client: Client,
    module: HloModule) -> Dict[str, float]: ...

class XlaOp: ...

class XlaBuilder:
  def __init__(self, name: str) -> None: ...
  def Build(self, root: Optional[XlaOp] = ...) -> XlaComputation: ...
  def GetShape(self, __op: XlaOp) -> Shape: ...
  build = Build
  def clear_op_metadata(self) -> None: ...
  get_shape = GetShape
  def get_program_shape(self, root: Optional[XlaOp] = ...) -> ProgramShape: ...
  def is_constant(self, __op: XlaOp) -> bool: ...
  def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ...
  def set_sharding(self, sharding: OpSharding_Type) -> None: ...
  def clear_sharding(self) -> None: ...
  def setup_alias(
      self,
      __output_index: Sequence[int],
      __param_number: int,
      __param_index: Sequence[int]) -> None: ...

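# Usage sketch (not part of the stub): building a two-input add with
# XlaBuilder. The Parameter/Add helpers come from the sibling `ops` module
# imported above; their exact signatures are an assumption here.
#
#   builder = XlaBuilder("add")
#   vec = Shape.array_shape(np.dtype("float32"), (2,))
#   x = ops.Parameter(builder, 0, vec)
#   y = ops.Parameter(builder, 1, vec)
#   ops.Add(x, y)
#   computation = builder.Build()  # or builder.Build(root) for an explicit root
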
class DeviceAssignment:
  @staticmethod
  def create(array: np.ndarray) -> DeviceAssignment: ...
  def replica_count(self) -> int: ...
  def computation_count(self) -> int: ...
  def __repr__(self) -> str: ...
  def serialize(self) -> bytes: ...

class CompileOptions:
  def __init__(self) -> None: ...
  argument_layouts: Optional[List[Shape]]
  parameter_is_tupled_arguments: bool
  executable_build_options: ExecutableBuildOptions
  tuple_arguments: bool
  num_replicas: int
  num_partitions: int
  device_assignment: Optional[DeviceAssignment]

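# Usage sketch (not part of the stub): CompileOptions is a plain mutable
# options bag; construct it, set fields, and pass it to Client.compile.
#
#   options = CompileOptions()
#   options.num_replicas = 1
#   options.parameter_is_tupled_arguments = False
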
def register_custom_call_target(fn_name: str, capsule: Any, platform: str) -> _Status: ...

class DebugOptions:
  def __repr__(self) -> str: ...
  xla_cpu_enable_fast_math: bool
  xla_cpu_fast_math_honor_infs: bool
  xla_cpu_fast_math_honor_nans: bool
  xla_cpu_fast_math_honor_division: bool
  xla_cpu_fast_math_honor_functions: bool
  xla_gpu_enable_fast_min_max: bool
  xla_backend_optimization_level: int
  xla_cpu_enable_xprof_traceme: bool
  xla_llvm_disable_expensive_passes: bool
  xla_test_all_input_layouts: bool

class ExecutableBuildOptions:
  def __init__(self) -> None: ...
  def __repr__(self) -> str: ...
  result_layout: Optional[Shape]
  num_replicas: int
  num_partitions: int
  debug_options: DebugOptions
  device_assignment: Optional[DeviceAssignment]
  use_spmd_partitioning: bool

class PrecisionConfig_Precision(enum.IntEnum):
  DEFAULT: int
  HIGH: int
  HIGHEST: int

class OpSharding_Type(enum.IntEnum):
  REPLICATED: int
  MAXIMAL: int
  TUPLE: int
  OTHER: int

class ChannelHandle_ChannelType(enum.IntEnum):
  CHANNEL_TYPE_INVALID: int
  DEVICE_TO_DEVICE: int
  DEVICE_TO_HOST: int
  HOST_TO_DEVICE: int

class ChannelHandle:
  type: ChannelHandle_ChannelType
  handle: int
  def __repr__(self) -> str: ...

class FftType(enum.IntEnum):
  FFT: int
  IFFT: int
  RFFT: int
  IRFFT: int

# === END xla_compiler.cc

class Device:
  id: int
  host_id: int
  process_index: int
  platform: str
  device_kind: str
  client: Client
  def __str__(self) -> str: ...
  def transfer_to_infeed(self, literal: _LiteralSlice): ...
  def transfer_from_outfeed(self, shape: Shape): ...
  def live_buffers(self) -> List[Buffer]: ...

class CpuDevice(Device):
  def __repr__(self) -> str: ...

class GpuDevice(Device):
  device_vendor: str
  def __repr__(self) -> str: ...

class TpuDevice(Device):
  coords: Tuple[int, ...]
  core_on_chip: int
  def __repr__(self) -> str: ...

class _GpuAllocatorKind(enum.IntEnum):
  DEFAULT: int
  PLATFORM: int
  BFC: int
  CUDA_ASYNC: int

class GpuAllocatorConfig:
  # TODO(b/194673104): Remove once pytype correctly resolves a nested enum.
  Kind = _GpuAllocatorKind

  def __init__(
      self,
      kind: _GpuAllocatorKind = ...,
      memory_fraction: float = ...,
      preallocate: bool = ...) -> None: ...

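# Usage sketch (not part of the stub): a BFC allocator capped at half of GPU
# memory with preallocation disabled, for use with get_gpu_client below.
#
#   config = GpuAllocatorConfig(
#       kind=GpuAllocatorConfig.Kind.BFC,
#       memory_fraction=0.5,
#       preallocate=False)
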
class HostBufferSemantics(enum.IntEnum):
  IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics
  IMMUTABLE_UNTIL_TRANSFER_COMPLETES: HostBufferSemantics
  ZERO_COPY: HostBufferSemantics

class Client:
  platform: str
  platform_version: str
  runtime_type: str
  def device_count(self) -> int: ...
  def local_device_count(self) -> int: ...
  def devices(self) -> List[Device]: ...
  def local_devices(self) -> List[Device]: ...
  def live_buffers(self) -> List[Buffer]: ...
  def live_executables(self) -> List[Executable]: ...
  def host_id(self) -> int: ...
  def process_index(self) -> int: ...
  @overload
  def get_default_device_assignment(
      self,
      num_replicas: int,
      num_partitions: int) -> List[List[Device]]: ...
  @overload
  def get_default_device_assignment(
      self,
      num_replicas: int) -> List[Device]: ...
  def create_channel_handle(self) -> ChannelHandle: ...
  def create_device_to_host_channel_handle(self) -> ChannelHandle: ...
  def create_host_to_device_channel_handle(self) -> ChannelHandle: ...
  def buffer_from_pyval(
      self,
      argument: Any,
      device: Device = ...,
      force_copy: bool = ...,
      host_buffer_semantics: HostBufferSemantics = ...) -> Buffer: ...
  def compile(
      self,
      computation: XlaComputation,
      compile_options: CompileOptions = ...) -> Executable: ...
  def serialize_executable(self, executable: Executable) -> bytes: ...
  @overload
  def deserialize_executable(
      self, serialized: bytes,
      options: CompileOptions) -> Executable: ...
  # TODO(skyewm): remove when jax stops providing hlo_module
  @overload
  def deserialize_executable(
      self, serialized: bytes,
      hlo_module: HloModule,
      options: CompileOptions) -> Executable: ...
  def heap_profile(self) -> bytes: ...
  def defragment(self) -> _Status: ...
  def emit_python_callback(
      self, callable: Callable, builder: XlaBuilder, operands: Sequence[XlaOp],
      results_shapes: Sequence[Shape],
      operand_layouts: Optional[Sequence[Shape]] = ...,
      has_side_effects: bool = ...) -> Tuple[XlaOp, Any]: ...


def get_cpu_client(asynchronous: bool = ...) -> Client: ...
def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ...
def get_interpreter_client() -> Client: ...
def get_gpu_client(
    asynchronous: bool = ...,
    allocator_config: GpuAllocatorConfig = ...,
    distributed_client: Optional[DistributedRuntimeClient] = ...,
    node_id: int = ...) -> Client: ...
def get_tpu_client(max_inflight_computations: int = ...) -> Client: ...

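# Usage sketch (not part of the stub): compiling and running the
# `computation` built in the XlaBuilder example above on CPU.
#
#   client = get_cpu_client()
#   executable = client.compile(computation, CompileOptions())
#   a = client.buffer_from_pyval(np.array([1.0, 2.0], np.float32))
#   b = client.buffer_from_pyval(np.array([3.0, 4.0], np.float32))
#   out, = executable.execute([a, b])
#   print(out.to_py())  # -> [4. 6.]
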
class DeviceArrayBase: ...

class DeviceArray(DeviceArrayBase):
  __array_priority__: int
  _device: Optional[Device]
  aval: Any
  weak_type: Optional[bool]
  _lazy_expr: Any
  @property
  def device_buffer(self: _T) -> _T: ...
  shape: Tuple[int, ...]
  dtype: np.dtype
  size: int
  ndim: int
  _value: np.ndarray
  def copy_to_device(self, dst_device: Device) -> DeviceArray: ...
  def on_device_size_in_bytes(self) -> int: ...
  def delete(self) -> None: ...
  def block_until_ready(self) -> DeviceArray: ...
  def copy_to_host_async(self) -> _Status: ...
  def to_py(self) -> np.ndarray: ...
  def xla_shape(self) -> Shape: ...
  def xla_dynamic_shape(self) -> Shape: ...
  client: Client
  def device(self) -> Device: ...
  def platform(self) -> str: ...
  def is_deleted(self) -> bool: ...
  def unsafe_buffer_pointer(self) -> Any: ...
  __cuda_array_interface__: Dict[str, Any]
  traceback: Traceback
  def clone(self) -> DeviceArray: ...

PyLocalBuffer = DeviceArray
Buffer = DeviceArray

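# Usage sketch (not part of the stub): a DeviceArray (aliased as Buffer above)
# is the host-side handle to device memory; `to_py` copies back to NumPy.
#
#   buf = client.buffer_from_pyval(np.ones((2, 2), np.float32))
#   buf.block_until_ready()   # wait for any pending transfer or computation
#   host = buf.to_py()        # np.ndarray copy on the host
#   buf.delete()              # release the device allocation
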
class Executable:
  client: Client
  def local_logical_device_ids(self) -> List[Tuple[int, int]]: ...
  def local_devices(self) -> List[Device]: ...
  def size_of_generated_code_in_bytes(self) -> int: ...
  def delete(self) -> None: ...
  def execute(self, arguments: Sequence[DeviceArray]) -> List[DeviceArray]: ...
  def execute_sharded_on_local_devices(
      self,
      arguments: Sequence[List[DeviceArray]]) -> List[List[DeviceArray]]: ...
  def hlo_modules(self) -> List[HloModule]: ...
  def keep_alive(self) -> None: ...
  traceback: Traceback
  fingerprint: Optional[bytes]

def buffer_to_dlpack_managed_tensor(
    buffer: Buffer,
    take_ownership: bool = ...) -> Any: ...
def dlpack_managed_tensor_to_buffer(
    tensor: Any, cpu_backend: Optional[Client] = ...,
    gpu_backend: Optional[Client] = ...) -> Buffer: ...

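# Usage sketch (not part of the stub): round-tripping a buffer through a
# DLPack capsule; with take_ownership=True the source buffer is invalidated.
#
#   capsule = buffer_to_dlpack_managed_tensor(buf, take_ownership=True)
#   buf2 = dlpack_managed_tensor_to_buffer(capsule, cpu_backend=client)
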
# === BEGIN py_traceback.cc

class Frame:
  file_name: str
  function_name: str
  function_line_start: int
  line_num: int
  def __repr__(self) -> str: ...

class Traceback:
  enabled: ClassVar[bool]
  @staticmethod
  def get_traceback() -> Traceback: ...
  frames: Sequence[Frame]
  def __str__(self) -> str: ...
  def as_python_traceback(self) -> Any: ...

def replace_thread_exc_traceback(traceback: Any): ...

# === END py_traceback.cc

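# Usage sketch (not part of the stub): frame collection is gated on the
# class-level flag, so enable it before asking for a traceback.
#
#   Traceback.enabled = True
#   tb = Traceback.get_traceback()
#   print(tb.frames[0].file_name if tb.frames else "<no frames>")
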
class DistributedRuntimeService: ...
class DistributedRuntimeClient:
  def connect(self) -> _Status: ...
  def shutdown(self) -> _Status: ...

def get_distributed_runtime_service(
    address: str,
    num_nodes: int,
    heartbeat_interval: Optional[int],
    max_missing_heartbeats: Optional[int],
    enumerate_devices_timeout: Optional[int],
    shutdown_timeout: Optional[int]) -> DistributedRuntimeService: ...
def get_distributed_runtime_client(
    address: str,
    node_id: int,
    rpc_timeout: Optional[int],
    init_timeout: Optional[int],
    shutdown_timeout: Optional[int],
    heartbeat_interval: Optional[int],
    max_missing_heartbeats: Optional[int],
    missed_heartbeat_callback: Optional[Any],
    shutdown_on_destruction: Optional[bool]) -> DistributedRuntimeClient: ...

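# Usage sketch (not part of the stub): one process hosts the service and each
# participating process connects a client before creating its GPU client.
# Addresses and node count below are illustrative; the Optional arguments are
# passed positionally as declared above, with None for defaults.
#
#   service = get_distributed_runtime_service(
#       "[::]:8476", 2, None, None, None, None)
#   dist_client = get_distributed_runtime_client(
#       "coordinator:8476", 0, None, None, None, None, None, None, None)
#   dist_client.connect()
#   client = get_gpu_client(distributed_client=dist_client, node_id=0)
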
def collect_garbage() -> None: ...

def is_optimized_build() -> bool: ...


class CompiledFunctionCache:
  def __init__(self, capacity: int = ...) -> None: ...
  def __getstate__(self) -> Any: ...
  def __setstate__(self, state: Any) -> None: ...
  def size(self) -> int: ...
  def capacity(self) -> int: ...
  def clear(self) -> None: ...

class CompiledFunction:
  def __call__(self, *args, **kwargs) -> Any: ...
  def __getstate__(self) -> Any: ...
  def __setstate__(self, state: Any) -> None: ...
  __signature__: inspect.Signature
  def _cache_size(self) -> int: ...
  def _clear_cache(self) -> None: ...