Home
last modified time | relevance | path

Searched refs:shape_representation_fn (Results 1 – 22 of 22) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dcompile_mlir_util.cc85 const XlaHelpers::ShapeRepresentationFn shape_representation_fn, in GetXlaInputShapes() argument
105 shape_representation_fn(arg_shapes[i].shape, dtype, in GetXlaInputShapes()
122 shape_representation_fn, &xla_shape)); in GetXlaInputShapes()
138 const XlaHelpers::ShapeRepresentationFn shape_representation_fn, in GetOutputInfo() argument
142 [shape_representation_fn](const TensorShape& shape, DataType dtype) { in GetOutputInfo()
143 return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); in GetOutputInfo()
392 const XlaHelpers::ShapeRepresentationFn shape_representation_fn, in ConvertMLIRToXlaComputation() argument
401 shape_representation_fn)); in ConvertMLIRToXlaComputation()
408 XlaHelpers::ShapeRepresentationFn* shape_representation_fn) { in CompileMlirSetup() argument
415 if (!*shape_representation_fn) in CompileMlirSetup()
[all …]
Dcompile_mlir_util.h74 const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
105 XlaHelpers::ShapeRepresentationFn shape_representation_fn,
115 XlaHelpers::ShapeRepresentationFn shape_representation_fn,
125 const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
138 const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
150 const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
Dtf_xla_mlir_translate.cc286 XlaHelpers::ShapeRepresentationFn shape_representation_fn = in CompileMlirToXlaHloViaBuilder() local
290 shape_representation_fn, compilation_result); in CompileMlirToXlaHloViaBuilder()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_helpers.cc141 XlaHelpers::ShapeRepresentationFn shape_representation_fn, in RewriteLayoutWithShardedShape() argument
171 shape_representation_fn(per_device_tensor_shape, dtype, in RewriteLayoutWithShardedShape()
182 XlaHelpers::ShapeRepresentationFn shape_representation_fn, in ReshapeWithCorrectRepresentationAndSharding() argument
192 shape_representation_fn, subsharding, fast_mem)); in ReshapeWithCorrectRepresentationAndSharding()
203 shape_representation_fn(shape, dtype, fast_mem)); in ReshapeWithCorrectRepresentationAndSharding()
208 hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); in ReshapeWithCorrectRepresentationAndSharding()
Dxla_helpers.h88 XlaHelpers::ShapeRepresentationFn shape_representation_fn,
95 XlaHelpers::ShapeRepresentationFn shape_representation_fn,
Dxla_compiler.cc171 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, in BuildComputation() argument
231 if (shape_representation_fn) { in BuildComputation()
236 shape_representation_fn, sharding, in BuildComputation()
330 if (shape_representation_fn) { in BuildComputation()
335 shape_representation_fn, sharding, arg.fast_mem)); in BuildComputation()
526 if (!options_.shape_representation_fn) { in XlaCompiler()
527 options_.shape_representation_fn = IdentityShapeRepresentationFn(); in XlaCompiler()
821 options_.shape_representation_fn, result)); in CompileFunction()
849 TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( in XLAShapeForArgument()
854 options_.shape_representation_fn, xla_shape)); in XLAShapeForArgument()
[all …]
Dxla_compiler.h176 ShapeRepresentationFn shape_representation_fn; member
Dxla_compiler_test.cc306 options.shape_representation_fn = in TEST_F()
349 options.shape_representation_fn = in TEST_F()
402 options.shape_representation_fn = in TEST_F()
1098 options.shape_representation_fn = in TEST_F()
1139 options.shape_representation_fn = in TEST_F()
1273 options.shape_representation_fn = in TEST_F()
1344 options.shape_representation_fn = in TEST_F()
Dxla_op_kernel.cc494 ctx->compiler()->options().shape_representation_fn( in ReadVariableInputTensor()
637 ctx->compiler()->options().shape_representation_fn( in AssignVariableTensor()
/external/tensorflow/tensorflow/compiler/jit/
Dxla_device.h63 XlaCompiler::ShapeRepresentationFn shape_representation_fn,
72 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { in shape_representation_fn() function
128 XlaCompiler::ShapeRepresentationFn shape_representation_fn; member
Dxla_device_context.h60 XlaCompiler::ShapeRepresentationFn shape_representation_fn,
82 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { in shape_representation_fn() function
Dxla_platform_info.cc148 options.shape_representation_fn = in GenerateCompilerOptions()
149 platform_info.xla_device_metadata()->shape_representation_fn(); in GenerateCompilerOptions()
Dxla_device.cc148 XlaCompiler::ShapeRepresentationFn shape_representation_fn, in Metadata() argument
153 shape_representation_fn_(std::move(shape_representation_fn)), in Metadata()
203 options.shape_representation_fn, in XlaDevice()
213 shape_representation_fn_(options.shape_representation_fn), in XlaDevice()
Dxla_device_context.cc82 XlaCompiler::ShapeRepresentationFn shape_representation_fn, in XlaDeviceContext() argument
90 shape_representation_fn_(std::move(shape_representation_fn)), in XlaDeviceContext()
Dxla_tpu_device.cc201 dst_xla_context->shape_representation_fn()( in TpuDeviceToDeviceCopy()
374 options.shape_representation_fn = &TpuShapeRepresentation; in CreateDevices()
Dxla_compilation_cache.cc308 *options.flib_def, debug_info, options.shape_representation_fn, result); in CompileSingleOp()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dtype_to_shape.h40 mlir::Type type, CustomShapeRepresentationFn shape_representation_fn);
Dmlir_hlo_to_hlo.h52 shape_representation_fn = nullptr,
Dtype_to_shape.cc84 mlir::Type type, CustomShapeRepresentationFn shape_representation_fn) { in TypeToShape() argument
97 return shape_representation_fn(fully_defined_tensor_shape, dtype); in TypeToShape()
Dmlir_hlo_to_hlo.cc446 tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, in ConvertToHloModule() argument
452 shape_representation_fn_(shape_representation_fn), in ConvertToHloModule()
1769 const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, in ConvertMlirHloToHlo() argument
1779 return_tuple, shape_representation_fn, options); in ConvertMlirHloToHlo()
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtpu_compile_op_common.h174 const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
188 const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
Dtpu_compile_op_common.cc257 const XlaCompiler::ShapeRepresentationFn shape_representation_fn, in GetShardingInfo() argument
267 shape_representation_fn(arg_shapes[i], proto_arg.dtype(), in GetShardingInfo()
271 shape_representation_fn, &xla_arg_shape)); in GetShardingInfo()
280 const XlaCompiler::ShapeRepresentationFn shape_representation_fn, in CompileTFFunctionToHlo() argument
295 compiler_options.shape_representation_fn = shape_representation_fn; in CompileTFFunctionToHlo()