# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import enum
import inspect
from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union, overload

import numpy as np

from . import ops
from . import jax_jit
from . import outfeed_receiver
from . import pmap_lib
from . import profiler
from . import pytree

_LiteralSlice = Any
_Status = Any
_Dtype = Any
_XlaOpMetadata = Any

_T = TypeVar("_T")

class PrimitiveType(enum.IntEnum):
  PRIMITIVE_TYPE_INVALID: PrimitiveType
  PRED: PrimitiveType
  S8: PrimitiveType
  S16: PrimitiveType
  S32: PrimitiveType
  S64: PrimitiveType
  U8: PrimitiveType
  U16: PrimitiveType
  U32: PrimitiveType
  U64: PrimitiveType
  BF16: PrimitiveType
  F16: PrimitiveType
  F32: PrimitiveType
  F64: PrimitiveType
  C64: PrimitiveType
  C128: PrimitiveType
  TUPLE: PrimitiveType
  OPAQUE_TYPE: PrimitiveType
  TOKEN: PrimitiveType

def bfloat16_dtype() -> Type[Any]: ...

# === BEGIN xla_compiler.cc

class Shape:
  def __init__(self, s: str): ...
  @staticmethod
  def tuple_shape(shapes: Sequence[Shape]) -> Shape: ...
  @staticmethod
  def array_shape(
      type: Union[np.dtype, PrimitiveType],
      dims_seq: Any = ...,
      layout_seq: Any = ...,
      dynamic_dimensions: Optional[List[bool]] = ...) -> Shape: ...
  @staticmethod
  def token_shape() -> Shape: ...
  @staticmethod
  def scalar_shape(type: Union[np.dtype, PrimitiveType]) -> Shape: ...
  def dimensions(self) -> Tuple[int, ...]: ...
  def xla_element_type(self) -> PrimitiveType: ...
  def element_type(self) -> np.dtype: ...
  def numpy_dtype(self) -> np.dtype: ...
  def is_tuple(self) -> bool: ...
  def is_array(self) -> bool: ...
  def is_token(self) -> bool: ...
  def is_static(self) -> bool: ...
  def is_dynamic(self) -> bool: ...
  def is_dynamic_dimension(self, dimension: int) -> bool: ...
  def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ...
  def rank(self) -> int: ...
  def to_serialized_proto(self) -> bytes: ...
  def tuple_shapes(self) -> List[Shape]: ...
  def leaf_count(self) -> int: ...
  def with_major_to_minor_layout_if_absent(self) -> Shape: ...
  def __eq__(self, other: Shape) -> bool: ...
  def __ne__(self, other: Shape) -> bool: ...
  def __hash__(self) -> int: ...
  def __repr__(self) -> str: ...

class ProgramShape:
  def __init__(self, params: Sequence[Shape], result: Shape) -> None: ...
  def parameter_shapes(self) -> List[Shape]: ...
  def result_shape(self) -> Shape: ...
  def __repr__(self) -> str: ...

class ShapeIndex:
  def __init__(self, indices: List[int]) -> None: ...
  def __eq__(self, other: ShapeIndex) -> bool: ...
  def __ne__(self, other: ShapeIndex) -> bool: ...
  def __hash__(self) -> int: ...
  def __repr__(self) -> str: ...
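
# Usage sketch (not part of the stub): constructing and inspecting shapes via
# the declarations above. The import name `xla_ext` is an assumption for
# illustration; in practice this extension is usually reached through a
# wrapper such as jaxlib's xla_client.
#
#   import numpy as np
#   f32 = xla_ext.Shape.array_shape(np.dtype("float32"), (2, 3))
#   assert f32.is_array() and f32.rank() == 2
#   assert f32.dimensions() == (2, 3)
#   tup = xla_ext.Shape.tuple_shape([f32, xla_ext.Shape.token_shape()])
#   assert tup.is_tuple() and len(tup.tuple_shapes()) == 2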

class Literal:
  def __repr__(self) -> str: ...

class XlaComputation:
  def __init__(self, serialized_hlo_module_proto: bytes) -> None: ...
  def get_hlo_module(self) -> HloModule: ...
  def program_shape(self) -> ProgramShape: ...
  def as_serialized_hlo_module_proto(self) -> bytes: ...
  def as_hlo_text(self) -> str: ...
  def as_hlo_dot_graph(self) -> str: ...
  def hash(self) -> int: ...
  def as_hlo_module(self) -> HloModule: ...

class HloPrintOptions:
  def __init__(self) -> None: ...
  @staticmethod
  def short_parsable() -> HloPrintOptions: ...
  @staticmethod
  def canonical() -> HloPrintOptions: ...
  @staticmethod
  def fingerprint() -> HloPrintOptions: ...
  print_large_constants: bool
  print_metadata: bool
  print_backend_config: bool
  print_result_shape: bool
  print_operand_shape: bool
  print_operand_names: bool
  print_ids: bool
  print_extra_attributes: bool
  print_program_shape: bool
  print_percent: bool
  print_control_dependencies: bool
  compact_operands: bool
  include_layout_in_shapes: bool
  canonicalize_instruction_names: bool
  canonicalize_computations: bool
  indent_amount: int
  is_in_nested_computation: bool
  leading_and_trailing_instructions_number: int

class HloModule:
  def to_string(self, options: HloPrintOptions = ...) -> str: ...

def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ...

def hlo_module_cost_analysis(
    client: Client,
    module: HloModule) -> Dict[str, float]: ...

class XlaOp: ...

class XlaBuilder:
  def __init__(self, name: str) -> None: ...
  def Build(self, root: Optional[XlaOp] = ...) -> XlaComputation: ...
  def GetShape(self, __op: XlaOp) -> Shape: ...
  build = Build
  def clear_op_metadata(self) -> None: ...
  get_shape = GetShape
  def get_program_shape(self, root: Optional[XlaOp] = ...) -> ProgramShape: ...
  def is_constant(self, __op: XlaOp) -> bool: ...
  def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ...
  def set_sharding(self, sharding: OpSharding_Type) -> None: ...
  def clear_sharding(self) -> None: ...
  def setup_alias(
      self,
      __output_index: Sequence[int],
      __param_number: int,
      __param_index: Sequence[int]) -> None: ...
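
# Usage sketch (not part of the stub): building a small computation with the
# sibling `ops` module imported at the top of this file. `ops.Parameter`,
# `ops.Constant`, and `ops.Add` are assumed to be exposed by `ops`; `xla_ext`
# is an assumed import name. Note `build` is an alias for `Build`.
#
#   import numpy as np
#   builder = xla_ext.XlaBuilder("add_one")
#   shape = xla_ext.Shape.array_shape(np.dtype("float32"), (4,))
#   x = ops.Parameter(builder, 0, shape)        # parameter number 0
#   one = ops.Constant(builder, np.float32(1))  # scalars broadcast in XLA
#   computation = builder.build(ops.Add(x, one))
#   print(computation.as_hlo_text())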

class DeviceAssignment:
  @staticmethod
  def create(array: np.ndarray) -> DeviceAssignment: ...
  def replica_count(self) -> int: ...
  def computation_count(self) -> int: ...
  def __repr__(self) -> str: ...
  def serialize(self) -> bytes: ...

class CompileOptions:
  def __init__(self) -> None: ...
  argument_layouts: Optional[List[Shape]]
  parameter_is_tupled_arguments: bool
  executable_build_options: ExecutableBuildOptions
  tuple_arguments: bool
  num_replicas: int
  num_partitions: int
  device_assignment: Optional[DeviceAssignment]

def register_custom_call_target(fn_name: str, capsule: Any, platform: str) -> _Status: ...

class DebugOptions:
  def __repr__(self) -> str: ...
  xla_cpu_enable_fast_math: bool
  xla_cpu_fast_math_honor_infs: bool
  xla_cpu_fast_math_honor_nans: bool
  xla_cpu_fast_math_honor_division: bool
  xla_cpu_fast_math_honor_functions: bool
  xla_gpu_enable_fast_min_max: bool
  xla_backend_optimization_level: int
  xla_cpu_enable_xprof_traceme: bool
  xla_llvm_disable_expensive_passes: bool
  xla_test_all_input_layouts: bool

class ExecutableBuildOptions:
  def __init__(self) -> None: ...
  def __repr__(self) -> str: ...
  result_layout: Optional[Shape]
  num_replicas: int
  num_partitions: int
  debug_options: DebugOptions
  device_assignment: Optional[DeviceAssignment]
  use_spmd_partitioning: bool

class PrecisionConfig_Precision(enum.IntEnum):
  DEFAULT: int
  HIGH: int
  HIGHEST: int

class OpSharding_Type(enum.IntEnum):
  REPLICATED: int
  MAXIMAL: int
  TUPLE: int
  OTHER: int

class ChannelHandle_ChannelType(enum.IntEnum):
  CHANNEL_TYPE_INVALID: int
  DEVICE_TO_DEVICE: int
  DEVICE_TO_HOST: int
  HOST_TO_DEVICE: int

class ChannelHandle:
  type: ChannelHandle_ChannelType
  handle: int
  def __repr__(self) -> str: ...

class FftType(enum.IntEnum):
  FFT: int
  IFFT: int
  RFFT: int
  IRFFT: int

# === END xla_compiler.cc

class Device:
  id: int
  host_id: int
  process_index: int
  platform: str
  device_kind: str
  client: Client
  def __str__(self) -> str: ...
  def transfer_to_infeed(self, literal: _LiteralSlice): ...
  def transfer_from_outfeed(self, shape: Shape): ...
  def live_buffers(self) -> List[Buffer]: ...

class CpuDevice(Device):
  def __repr__(self) -> str: ...

class GpuDevice(Device):
  device_vendor: str
  def __repr__(self) -> str: ...

class TpuDevice(Device):
  coords: Tuple[int, ...]
  core_on_chip: int
  def __repr__(self) -> str: ...

class _GpuAllocatorKind(enum.IntEnum):
  DEFAULT: int
  PLATFORM: int
  BFC: int
  CUDA_ASYNC: int

class GpuAllocatorConfig:
  # TODO(b/194673104): Remove once pytype correctly resolves a nested enum.
  Kind = _GpuAllocatorKind

  def __init__(
      self,
      kind: _GpuAllocatorKind = ...,
      memory_fraction: float = ...,
      preallocate: bool = ...) -> None: ...

class HostBufferSemantics(enum.IntEnum):
  IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics
  IMMUTABLE_UNTIL_TRANSFER_COMPLETES: HostBufferSemantics
  ZERO_COPY: HostBufferSemantics

class Client:
  platform: str
  platform_version: str
  runtime_type: str
  def device_count(self) -> int: ...
  def local_device_count(self) -> int: ...
  def devices(self) -> List[Device]: ...
  def local_devices(self) -> List[Device]: ...
  def live_buffers(self) -> List[Buffer]: ...
  def live_executables(self) -> List[Executable]: ...
  def host_id(self) -> int: ...
  def process_index(self) -> int: ...
  @overload
  def get_default_device_assignment(
      self,
      num_replicas: int,
      num_partitions: int) -> List[List[Device]]: ...
  @overload
  def get_default_device_assignment(
      self,
      num_replicas: int) -> List[Device]: ...
  def create_channel_handle(self) -> ChannelHandle: ...
  def create_device_to_host_channel_handle(self) -> ChannelHandle: ...
  def create_host_to_device_channel_handle(self) -> ChannelHandle: ...
  def buffer_from_pyval(
      self,
      argument: Any,
      device: Device = ...,
      force_copy: bool = ...,
      host_buffer_semantics: HostBufferSemantics = ...) -> Buffer: ...
  def compile(
      self,
      computation: XlaComputation,
      compile_options: CompileOptions = ...) -> Executable: ...
  def serialize_executable(self, executable: Executable) -> bytes: ...
  @overload
  def deserialize_executable(
      self, serialized: bytes,
      options: CompileOptions) -> Executable: ...
  # TODO(skyewm): remove when jax stops providing hlo_module
  @overload
  def deserialize_executable(
      self, serialized: bytes,
      hlo_module: HloModule,
      options: CompileOptions) -> Executable: ...
  def heap_profile(self) -> bytes: ...
  def defragment(self) -> _Status: ...
  def emit_python_callback(
      self, callable: Callable, builder: XlaBuilder, operands: Sequence[XlaOp],
      results_shapes: Sequence[Shape],
      operand_layouts: Optional[Sequence[Shape]] = ...,
      has_side_effects: bool = ...) -> Tuple[XlaOp, Any]: ...
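
# Usage sketch (not part of the stub): a minimal compile-and-run round trip,
# assuming the `computation` built in the XlaBuilder sketch above and the
# `get_cpu_client` factory declared later in this file.
#
#   client = get_cpu_client(asynchronous=True)
#   buf = client.buffer_from_pyval(np.ones((4,), np.float32))
#   executable = client.compile(computation, CompileOptions())
#   (out,) = executable.execute([buf])  # one output for this computation
#   print(out.to_py())                  # host copy of the result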

def get_cpu_client(asynchronous: bool = ...) -> Client: ...
def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ...
def get_interpreter_client() -> Client: ...
def get_gpu_client(
    asynchronous: bool = ...,
    allocator_config: GpuAllocatorConfig = ...,
    distributed_client: Optional[DistributedRuntimeClient] = ...,
    node_id: int = ...) -> Client: ...
def get_tpu_client(max_inflight_computations: int = ...) -> Client: ...

class DeviceArrayBase: ...

class DeviceArray(DeviceArrayBase):
  __array_priority__: int
  _device: Optional[Device]
  aval: Any
  weak_type: Optional[bool]
  _lazy_expr: Any
  @property
  def device_buffer(self: _T) -> _T: ...
  shape: Tuple[int, ...]
  dtype: np.dtype
  size: int
  ndim: int
  _value: np.ndarray
  def copy_to_device(self, dst_device: Device) -> DeviceArray: ...
  def on_device_size_in_bytes(self) -> int: ...
  def delete(self) -> None: ...
  def block_until_ready(self) -> DeviceArray: ...
  def copy_to_host_async(self) -> _Status: ...
  def to_py(self) -> np.ndarray: ...
  def xla_shape(self) -> Shape: ...
  def xla_dynamic_shape(self) -> Shape: ...
  client: Client
  def device(self) -> Device: ...
  def platform(self) -> str: ...
  def is_deleted(self) -> bool: ...
  def unsafe_buffer_pointer(self) -> Any: ...
  __cuda_array_interface__: Dict[str, Any]
  traceback: Traceback
  def clone(self) -> DeviceArray: ...

PyLocalBuffer = DeviceArray
Buffer = DeviceArray

class Executable:
  client: Client
  def local_logical_device_ids(self) -> List[Tuple[int, int]]: ...
  def local_devices(self) -> List[Device]: ...
  def size_of_generated_code_in_bytes(self) -> int: ...
  def delete(self) -> None: ...
  def execute(self, arguments: Sequence[DeviceArray]) -> List[DeviceArray]: ...
  def execute_sharded_on_local_devices(
      self,
      arguments: Sequence[List[DeviceArray]]) -> List[List[DeviceArray]]: ...
  def hlo_modules(self) -> List[HloModule]: ...
  def keep_alive(self) -> None: ...
  traceback: Traceback
  fingerprint: Optional[bytes]

def buffer_to_dlpack_managed_tensor(
    buffer: Buffer,
    take_ownership: bool = ...) -> Any: ...
def dlpack_managed_tensor_to_buffer(
    tensor: Any, cpu_backend: Optional[Client] = ...,
    gpu_backend: Optional[Client] = ...) -> Buffer: ...
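
# Usage sketch (not part of the stub): round-tripping a buffer through DLPack,
# e.g. to share device memory with another framework without a copy. Assumes
# `client` and `buf` from the Client sketch above.
#
#   capsule = buffer_to_dlpack_managed_tensor(buf, take_ownership=True)
#   # With take_ownership=True, `buf` is donated and becomes invalid;
#   # reconstruct a Buffer backed by the same memory on the same client.
#   buf2 = dlpack_managed_tensor_to_buffer(capsule, cpu_backend=client)
#   print(buf2.to_py())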

# === BEGIN py_traceback.cc

class Frame:
  file_name: str
  function_name: str
  function_line_start: int
  line_num: int
  def __repr__(self) -> str: ...

class Traceback:
  enabled: ClassVar[bool]
  @staticmethod
  def get_traceback() -> Traceback: ...
  frames: Sequence[Frame]
  def __str__(self) -> str: ...
  def as_python_traceback(self) -> Any: ...

def replace_thread_exc_traceback(traceback: Any): ...

# === END py_traceback.cc

class DistributedRuntimeService: ...

class DistributedRuntimeClient:
  def connect(self) -> _Status: ...
  def shutdown(self) -> _Status: ...

def get_distributed_runtime_service(
    address: str,
    num_nodes: int,
    heartbeat_interval: Optional[int],
    max_missing_heartbeats: Optional[int],
    enumerate_devices_timeout: Optional[int],
    shutdown_timeout: Optional[int]) -> DistributedRuntimeService: ...
def get_distributed_runtime_client(
    address: str,
    node_id: int,
    rpc_timeout: Optional[int],
    init_timeout: Optional[int],
    shutdown_timeout: Optional[int],
    heartbeat_interval: Optional[int],
    max_missing_heartbeats: Optional[int],
    missed_heartbeat_callback: Optional[Any],
    shutdown_on_destruction: Optional[bool]) -> DistributedRuntimeClient: ...

def collect_garbage() -> None: ...

def is_optimized_build() -> bool: ...

class CompiledFunctionCache:
  def __init__(self, capacity: int = ...): ...
  def __getstate__(self) -> Any: ...
  def __setstate__(self, state: Any): ...
  def size(self) -> int: ...
  def capacity(self) -> int: ...
  def clear(self): ...

class CompiledFunction:
  def __call__(self, *args, **kwargs) -> Any: ...
  def __getstate__(self) -> Any: ...
  def __setstate__(self, state: Any): ...
  __signature__: inspect.Signature
  def _cache_size(self) -> int: ...
  def _clear_cache(self) -> None: ...
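
# Usage sketch (not part of the stub): wiring up the distributed runtime
# declared above. Node 0 hosts the coordination service; every node (including
# node 0) connects a client before creating its GPU client. `node_id` is an
# assumed variable, and the Optional timeout/heartbeat arguments are passed as
# None here only to illustrate the positional signatures; real deployments
# would pick concrete values.
#
#   service = get_distributed_runtime_service(   # on node 0 only
#       "localhost:8476", 2, None, None, None, None)
#   client = get_distributed_runtime_client(     # on every node
#       "localhost:8476", node_id, None, None, None, None, None, None, None)
#   client.connect()
#   gpu_client = get_gpu_client(distributed_client=client, node_id=node_id)
#   ...
#   client.shutdown()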