Home
last modified time | relevance | path

Searched refs:arg_shapes (Results 1 – 25 of 30) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dcompile_mlir_util.cc83 mlir::ModuleOp module, llvm::ArrayRef<TensorOrResourceShape> arg_shapes, in GetXlaInputShapes() argument
105 shape_representation_fn(arg_shapes[i].shape, dtype, in GetXlaInputShapes()
228 Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes, in RefineShapes() argument
239 for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) { in RefineShapes()
244 arg_shapes_copy.reserve(arg_shapes.size()); in RefineShapes()
246 for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) { in RefineShapes()
407 mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes, in CompileMlirSetup() argument
410 TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op)); in CompileMlirSetup()
424 llvm::ArrayRef<TensorOrResourceShape> arg_shapes, in BuildHloFromTf() argument
433 CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn)); in BuildHloFromTf()
[all …]
Dtf_xla_mlir_translate.cc119 llvm::SmallVectorImpl<TensorOrResourceShape>& arg_shapes) { in ParseArgumentShapes() argument
120 arg_shapes.clear(); in ParseArgumentShapes()
123 arg_shapes.resize(input_shapes_vector.size()); in ParseArgumentShapes()
127 static_cast<int*>(nullptr), 0, &arg_shapes[shape.index()].shape)); in ParseArgumentShapes()
131 shape.value().getValue(), &arg_shapes[shape.index()].shape)); in ParseArgumentShapes()
241 mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes, in CompileMlirToXlaHloViaBuilder() argument
248 TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op)); in CompileMlirToXlaHloViaBuilder()
266 arg_shapes, device_type, in CompileMlirToXlaHloViaBuilder()
288 return PopulateResultIOInfo(module_op, arg_shapes, /*use_tuple_args=*/false, in CompileMlirToXlaHloViaBuilder()
297 llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes; in MlirTfToHloTextTranslateFunctionImpl() local
[all …]
Dcompile_mlir_util.h85 Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
94 llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
103 mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
112 mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
123 llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
/external/tensorflow/tensorflow/core/kernels/data/
Dmap_defun_op.cc48 const std::vector<TensorShape> arg_shapes; member
62 std::vector<TensorShape> arg_shapes, int64 batch_size, in ComputeOptions()
66 arg_shapes(std::move(arg_shapes)), in ComputeOptions()
111 compute_opts_->arg_shapes.at(index)); in GetArg()
277 std::vector<TensorShape> arg_shapes; in SetupArgs() local
278 arg_shapes.reserve(arguments.size()); in SetupArgs()
281 arg_shapes.push_back(arguments[i].shape()); in SetupArgs()
282 arg_shapes.at(i).RemoveDim(0); in SetupArgs()
286 new ComputeOptions(ctx, arguments, captured_inputs, std::move(arg_shapes), in SetupArgs()
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtpu_compile_op_common.cc180 const std::vector<TensorShape>& arg_shapes, in BuildComputationArgumentDescriptions() argument
193 arg.shape = arg_shapes[i]; in BuildComputationArgumentDescriptions()
256 absl::Span<const TensorShape> arg_shapes, in GetShardingInfo() argument
267 shape_representation_fn(arg_shapes[i], proto_arg.dtype(), in GetShardingInfo()
281 const std::vector<TensorShape>& arg_shapes, in CompileTFFunctionToHlo() argument
301 arg_shapes, guaranteed_constants, *compiler, &args, arg_core_mapping, in CompileTFFunctionToHlo()
336 arg_shape_dims.reserve(arg_shapes.size()); in CompileTFFunctionToHlo()
337 std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size()); in CompileTFFunctionToHlo()
338 for (const TensorShape& shape : arg_shapes) { in CompileTFFunctionToHlo()
457 const std::vector<PartialTensorShape>& arg_shapes, Graph* graph, in RunShapeInferenceOnComputation() argument
[all …]
Dtpu_compile_op_common.h104 const std::vector<TensorShape>& arg_shapes,
111 const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
165 const std::vector<PartialTensorShape>& arg_shapes,
175 const std::vector<TensorShape>& arg_shapes,
187 absl::Span<const TensorShape> arg_shapes,
199 const std::vector<TensorShape>& arg_shapes,
Dtpu_compile_op_impl.cc31 const std::vector<TensorShape>& arg_shapes, in Compile() argument
35 CreateTpuCompilationRequest(computation, metadata_, arg_shapes)); in Compile()
Dtpu_compile_op_support.cc166 std::vector<Shape> arg_shapes; in GetPerDeviceShape() local
174 arg_shapes.push_back( in GetPerDeviceShape()
178 return xla::ShapeUtil::MakeTupleShape(arg_shapes); in GetPerDeviceShape()
340 const std::vector<TensorShape>& arg_shapes) { in CreateTpuCompilationRequest() argument
377 for (const TensorShape& shape : arg_shapes) { in CreateTpuCompilationRequest()
443 std::vector<TensorShape>* arg_shapes) { in ComputeArgumentShapes() argument
444 arg_shapes->resize(metadata.args_size()); in ComputeArgumentShapes()
457 TensorShape& shape = (*arg_shapes)[i]; in ComputeArgumentShapes()
Dtpu_compile_op_support.h145 const std::vector<TensorShape>& arg_shapes);
158 std::vector<TensorShape>* arg_shapes);
Dtpu_compile_op_impl.h57 const std::vector<TensorShape>& arg_shapes,
Dtpu_compile.proto45 repeated TensorShapeProto arg_shapes = 6; field
/external/tensorflow/tensorflow/compiler/xla/pjrt/
Dutils.cc76 std::vector<Shape> arg_shapes; in GetShardedProgramShapes() local
77 arg_shapes.resize(program_shape.parameters_size()); in GetShardedProgramShapes()
90 TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()], in GetShardedProgramShapes()
101 for (int i = 0; i < arg_shapes.size(); ++i) { in GetShardedProgramShapes()
102 if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) { in GetShardedProgramShapes()
109 return std::make_pair(arg_shapes, result_shape); in GetShardedProgramShapes()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_jit_compiled_cpu_function.cc111 std::vector<const xla::Shape*> arg_shapes; in Compile() local
112 arg_shapes.reserve(program_shape->parameters_size()); in Compile()
114 arg_shapes.push_back(&program_shape->parameters(i)); in Compile()
121 client->Compile(computation, arg_shapes, build_options)); in Compile()
/external/tensorflow/tensorflow/compiler/jit/
Dshape_inference.cc46 const std::map<int, InferredShape>& arg_shapes, in PropagateShapes() argument
81 auto it = arg_shapes.find(index); in PropagateShapes()
82 if (it != arg_shapes.end()) { in PropagateShapes()
246 Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes, in InferShapes() argument
263 TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes, in InferShapes()
Dxla_compilation_cache.cc87 for (const auto& a : arg_shapes) { in HumanString()
100 if (arg_shapes != other.arg_shapes) return false; in operator ==()
116 for (const auto& arg : signature.arg_shapes) { in operator ()()
145 signature.arg_shapes.emplace_back(arg.type, in BuildSignature()
Dshape_inference.h43 Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
Dxla_compilation_cache.h104 arg_shapes; member
/external/tensorflow/tensorflow/compiler/xla/service/
Dshape_inference.cc383 absl::Span<const Shape* const> arg_shapes, const int64 dimension) { in InferConcatOpShape() argument
384 if (arg_shapes.empty()) { in InferConcatOpShape()
387 if (dimension < 0 || dimension >= arg_shapes[0]->rank()) { in InferConcatOpShape()
393 for (const Shape* shape : arg_shapes) { in InferConcatOpShape()
434 for (size_t i = 1; i < arg_shapes.size(); ++i) { in InferConcatOpShape()
435 new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); in InferConcatOpShape()
441 for (const Shape* shape : arg_shapes) { in InferConcatOpShape()
1117 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply, in InferMapShape() argument
1119 if (arg_shapes.empty()) { in InferMapShape()
1124 const Shape* arg_shape = arg_shapes[0]; in InferMapShape()
[all …]
Dshape_inference.h83 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
160 absl::Span<const Shape* const> arg_shapes,
270 absl::Span<const Shape* const> arg_shapes, int64 dimension);
300 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
Dhlo_parser.cc1408 absl::InlinedVector<const Shape*, 2> arg_shapes; in ParseInstructionRhs() local
1409 arg_shapes.reserve(operands.size()); in ParseInstructionRhs()
1411 arg_shapes.push_back(&operand->shape()); in ParseInstructionRhs()
1413 return ShapeInference::InferVariadicOpShape(opcode, arg_shapes); in ParseInstructionRhs()
1555 absl::InlinedVector<const Shape*, 2> arg_shapes; in ParseInstructionRhs() local
1556 arg_shapes.reserve(operands.size()); in ParseInstructionRhs()
1558 arg_shapes.push_back(&operand->shape()); in ParseInstructionRhs()
1561 arg_shapes, to_apply.value()->ComputeProgramShape()); in ParseInstructionRhs()
1770 absl::InlinedVector<const Shape*, 2> arg_shapes; in ParseInstructionRhs() local
1771 arg_shapes.reserve(operands.size()); in ParseInstructionRhs()
[all …]
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
Ddistributed_tpu_rewrite_pass.h295 const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
307 const std::vector<InferredShape>& arg_shapes,
340 const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
354 const std::vector<InferredShape>& arg_shapes,
433 const std::vector<InferredShape>& arg_shapes,
Ddistributed_tpu_rewrite_pass.cc1694 const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes, in GetArgAndRetvalShapes() argument
1704 arg_shapes->clear(); in GetArgAndRetvalShapes()
1705 arg_shapes->resize(params_info.NumInputsToEachReplica()); in GetArgAndRetvalShapes()
1721 MergeInferredShapes((*arg_shapes)[input_index], *info); in GetArgAndRetvalShapes()
1725 (*arg_shapes)[input_index].shape.DebugString(), " vs. ", in GetArgAndRetvalShapes()
1728 (*arg_shapes)[input_index] = status.ValueOrDie(); in GetArgAndRetvalShapes()
1746 (*arg_shapes)[i].shape = PartialTensorShape(); in GetArgAndRetvalShapes()
1747 (*arg_shapes)[i].handle_shape = PartialTensorShape(); in GetArgAndRetvalShapes()
1756 (*arg_shapes)[i + params_info.NumPerReplicaArgs() + in GetArgAndRetvalShapes()
1769 (*arg_shapes)[i + params_info.NumPerReplicaArgs() + in GetArgAndRetvalShapes()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dshape_inference.h38 ArrayRef<ArrayRef<int64_t>> arg_shapes,
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dfake_quantize_ops.cc99 std::vector<xla::Shape> arg_shapes; in BuildFakeQuantCustomCall() local
104 arg_shapes.push_back(std::move(arg_shape)); in BuildFakeQuantCustomCall()
111 output_shape, arg_shapes); in BuildFakeQuantCustomCall()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dmlir_hlo_to_hlo.cc513 llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
519 llvm::SmallVectorImpl<xla::Shape>* arg_shapes);
1444 llvm::SmallVectorImpl<xla::Shape>* arg_shapes, in SetEntryTupleShapesAndLeafReplication() argument
1446 arg_shapes->reserve(block->getNumArguments()); in SetEntryTupleShapesAndLeafReplication()
1449 arg_shapes->push_back(xla::TypeToShape(arg.getType())); in SetEntryTupleShapesAndLeafReplication()
1450 xla::Shape& arg_shape = arg_shapes->back(); in SetEntryTupleShapesAndLeafReplication()
1482 llvm::SmallVectorImpl<xla::Shape>* arg_shapes) { in SetEntryTupleShardings() argument
1494 shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]); in SetEntryTupleShardings()
1519 llvm::SmallVector<xla::Shape, 4> arg_shapes; in LowerBasicBlockAsFunction() local
1522 block, entry_args_same_across_replicas, &arg_shapes, in LowerBasicBlockAsFunction()
[all …]

12