/external/tensorflow/tensorflow/compiler/xla/client/lib/ |
D | math.h | 26 XlaOp IsPosInf(XlaOp operand); 27 XlaOp IsNegInf(XlaOp operand); 28 XlaOp IsInf(XlaOp operand); 29 XlaOp IsNan(XlaOp operand); 34 XlaOp IsNegZero(XlaOp operand); 38 XlaOp NextAfter(XlaOp from, XlaOp to); 41 XlaOp Square(XlaOp operand); 44 XlaOp Reciprocal(XlaOp operand); 47 XlaOp Erfc(XlaOp x); 50 XlaOp Erf(XlaOp x); [all …]
|
D | slicing.h | 28 XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices, 34 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start); 38 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start, 43 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, 47 XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts, 50 XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, 51 absl::Span<const XlaOp> starts); 65 XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true); 71 XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim, 72 const std::function<XlaOp(XlaOp, XlaOp)>& combiner); [all …]
|
D | math.cc | 33 XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) { in EvaluatePolynomial() 36 XlaOp poly = ScalarLike(x, 0.0); in EvaluatePolynomial() 46 XlaOp EvaluateChebyshevPolynomial(XlaOp x, absl::Span<const FP> coefficients) { in EvaluateChebyshevPolynomial() 49 XlaOp b0 = ScalarLike(x, 0.0); in EvaluateChebyshevPolynomial() 50 XlaOp b1 = ScalarLike(x, 0.0); in EvaluateChebyshevPolynomial() 51 XlaOp b2 = ScalarLike(x, 0.0); in EvaluateChebyshevPolynomial() 65 static XlaOp DoWithUpcastToF32(XlaOp operand, in DoWithUpcastToF32() 67 const std::function<XlaOp(XlaOp)>& operation) { in DoWithUpcastToF32() 69 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in DoWithUpcastToF32() 77 XlaOp result = operation(operand); in DoWithUpcastToF32() [all …]
|
D | matrix.h | 33 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); 37 XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); 47 XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); 48 XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0); 51 XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0); 55 XlaOp TriangleMask(XlaOp x, int diagonal); 58 XlaOp Triangle(XlaOp x, bool lower); 61 XlaOp UpperTriangle(XlaOp x); 64 XlaOp LowerTriangle(XlaOp x); 83 xla::XlaOp BatchDot( [all …]
|
D | prng.h | 28 XlaOp value; 29 XlaOp state; 44 using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state, 50 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, 61 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, 64 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key); 70 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state, 72 XlaOp minval, XlaOp maxval, 77 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, 78 BitGeneratorTy bit_generator, XlaOp minval, [all …]
|
D | arithmetic.cc | 52 "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); }); in CreateScalarAddComputation() 58 "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); }); in CreateScalarMultiplyComputation() 64 "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); }); in CreateScalarGeComputation() 70 "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); }); in CreateScalarMaxComputation() 76 "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); }); in CreateScalarMinComputation() 82 "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); }); in CreateScalarAndComputation() 88 "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); }); in CreateScalarOrComputation() 100 XlaOp Any(XlaOp predicates) { in Any() 102 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in Any() 121 XlaOp lhs_value = in CreateMinMaxComputation() [all …]
|
D | prng.cc | 29 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, in ConcatScalars() 30 absl::Span<const xla::XlaOp> scalars) { in ConcatScalars() 31 std::vector<xla::XlaOp> vectors; in ConcatScalars() 33 [](xla::XlaOp x) { return xla::Reshape(x, {1}); }); in ConcatScalars() 40 XlaOp RotateLeftU32(XlaOp v, int distance) { in RotateLeftU32() 46 using ThreeFry2x32State = std::array<XlaOp, 2>; 60 std::array<XlaOp, 3> ks; in ThreeFry2x32() 122 std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) { in Uint64ToUint32s() 124 XlaOp const32 = ConstantR0WithType(builder, U64, 32); in Uint64ToUint32s() 125 XlaOp fst = ConvertElementType(u64, U32); in Uint64ToUint32s() [all …]
|
D | tridiagonal.cc | 50 StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal, in CheckSystemAndReturnNumEquations() 51 XlaOp main_diagonal, in CheckSystemAndReturnNumEquations() 52 XlaOp upper_diagonal, in CheckSystemAndReturnNumEquations() 53 XlaOp rhs) { in CheckSystemAndReturnNumEquations() 111 XlaOp Coefficient(XlaOp operand, int32 i) { in Coefficient() 117 XlaOp Coefficient(XlaOp operand, XlaOp i) { in Coefficient() 122 XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) { in UpdateEq() 127 XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) { in UpdateEq() 145 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal, in ThomasSolver() 146 XlaOp upper_diagonal, XlaOp rhs) { in ThomasSolver() [all …]
|
D | slicing.cc | 29 XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices, in DynamicStridedSlice() 32 XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes); in DynamicStridedSlice() 41 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start, in SliceInMinorDims() 44 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in SliceInMinorDims() 72 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) { in UpdateSlice() 74 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in UpdateSlice() 82 std::vector<XlaOp> start_ops(start.size()); in UpdateSlice() 90 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, in UpdateSliceInMinorDims() 93 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in UpdateSliceInMinorDims() 115 StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims( in PrependZerosInMajorDims() [all …]
|
/external/tensorflow/tensorflow/compiler/xla/client/ |
D | xla_builder.h | 49 class XlaOp; variable 55 static XlaOp BuildFusion(XlaBuilder* builder, 56 absl::Span<const XlaOp> operands, 60 static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, 63 static HloInstructionProto* GetInstruction(XlaOp op); 71 class XlaOp { 73 XlaOp() : handle_(-1), builder_(nullptr) { in XlaOp() function 74 static_assert(std::is_trivially_destructible<XlaOp>::value, in XlaOp() 77 ~XlaOp() = default; 79 XlaOp(const XlaOp& other) = default; [all …]
|
D | xla_builder.cc | 154 XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder, in BuildFusion() 155 absl::Span<const XlaOp> operands, in BuildFusion() 158 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in BuildFusion() 171 XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand, in BuildBitcast() 173 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in BuildBitcast() 181 HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) { in GetInstruction() 188 XlaOp operator-(XlaOp x) { return Neg(x); } in operator -() 189 XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); } in operator +() 190 XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); } in operator -() 191 XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); } in operator *() [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/xla/ir/ |
D | mlir_hlo_builder.h | 71 StatusOr<XlaOp> MakeXlaOp(mlir::Value val); 76 mlir::Value GetValue(XlaOp op) { in GetValue() 84 std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) { in GetValues() 102 StatusOr<const Shape*> GetShapePtr(XlaOp op) const override; 111 XlaOp ConstantLiteral(const LiteralSlice& literal) override; 113 StatusOr<XlaOp> ConvGeneralDilatedInternal( 114 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, 123 StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand, 127 StatusOr<XlaOp> TriangularSolveInternal( 128 const Shape& shape, XlaOp a, XlaOp b, [all …]
|
D | mlir_hlo_builder.cc | 72 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) { in MakeXlaOp() 81 return XlaOp(handle, this); in MakeXlaOp() 84 XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) { in ConstantLiteral() 85 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in ConstantLiteral() 93 StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal( in ConvGeneralDilatedInternal() 94 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, in ConvGeneralDilatedInternal() 119 StatusOr<XlaOp> MlirHloBuilder::FftInternal( in FftInternal() 120 const Shape& shape, XlaOp operand, FftType fft_type, in FftInternal() 131 StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal( in CustomCallInternal() 132 const string& call_target_name, absl::Span<const XlaOp> operands, in CustomCallInternal() [all …]
|
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | tensor_list_utils.h | 29 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized); 34 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list); 37 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, 38 xla::XlaOp* output_list); 43 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape); 48 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer); 53 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index); 58 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, 59 xla::XlaOp* result); 62 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, [all …]
|
D | fake_quantize_ops.cc | 49 const xla::XlaOp& min, const xla::XlaOp& max, in XlaNudge() 51 xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, in XlaNudge() 52 xla::XlaOp* scale) { in XlaNudge() 56 xla::XlaOp quant_min = in XlaNudge() 58 xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); in XlaNudge() 59 xla::XlaOp quant_max = in XlaNudge() 61 xla::XlaOp nudged_zero_point = in XlaNudge() 69 xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, in Quantize() 71 const xla::XlaOp& nudged_input_min, in Quantize() 72 const xla::XlaOp& nudged_input_max, in Quantize() [all …]
|
D | variable_ops.cc | 71 xla::XlaOp handle; in Compile() 97 xla::XlaOp handle; in Compile() 113 xla::XlaOp handle; in Compile() 133 xla::XlaOp input; in Compile() 136 xla::XlaOp gather; in Compile() 151 std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&, in ResourceScatterOp() argument 163 xla::XlaOp var_value; in Compile() 167 const xla::XlaOp indices = context->Input(1); in Compile() 168 const xla::XlaOp updates = context->Input(2); in Compile() 179 const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&, [all …]
|
D | tensor_list_utils.cc | 115 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { in IsTensorListInitialized() 121 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { in IsNestedTensorList() 132 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, in BuildNonNestedTensorList() 133 xla::XlaOp* output_list) { in BuildNonNestedTensorList() 139 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { in GetTensorListBufferShape() 150 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { in GetTensorListBuffer() 160 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { in GetTensorListPushIndex() 172 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, in SetTensorListPushIndex() 173 xla::XlaOp* result) { in SetTensorListPushIndex() 181 std::vector<xla::XlaOp> result_parts; in SetTensorListPushIndex() [all …]
|
D | image_ops.cc | 38 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, in RGBToHSV() 39 const std::array<xla::XlaOp, 3>& rgb, in RGBToHSV() argument 71 std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b, in HSVToRGB() 72 const std::array<xla::XlaOp, 3>& hsv, in HSVToRGB() argument 74 xla::XlaOp hue = hsv[0]; in HSVToRGB() 75 xla::XlaOp saturation = hsv[1]; in HSVToRGB() 76 xla::XlaOp value = hsv[2]; in HSVToRGB() 113 xla::XlaOp input = context->Input(0); in Compile() 115 xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, in Compile() 118 xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, in Compile() [all …]
|
D | reduction_ops.cc | 35 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue() 38 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer() 39 const xla::XlaOp& scalar_rhs) override { in BuildReducer() 53 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue() 57 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer() 58 const xla::XlaOp& scalar_rhs) override { in BuildReducer() 71 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue() 75 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer() 76 const xla::XlaOp& scalar_rhs) override { in BuildReducer() 103 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue() [all …]
|
D | stateful_random_ops.cc | 43 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in BitGen() 46 xla::XlaOp result = in BitGen() 48 xla::XlaOp data = xla::GetTupleElement(result, 1); in BitGen() 49 xla::XlaOp new_state = in BitGen() 54 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in BitGen() 56 xla::XlaOp result = xla::RngBitGenerator( in BitGen() 58 xla::XlaOp data = xla::GetTupleElement(result, 1); in BitGen() 59 xla::XlaOp new_state = xla::Reshape( in BitGen() 66 xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key, in StatefulRngUniform() 67 xla::XlaOp initial_state, in StatefulRngUniform() [all …]
|
D | training_ops.cc | 34 xla::XlaOp handle; in Compile() 59 xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr, in ProximalGradientDescentUpdate() 60 xla::XlaOp l1, xla::XlaOp l2, in ProximalGradientDescentUpdate() 61 xla::XlaOp grad) { in ProximalGradientDescentUpdate() 62 xla::XlaOp one = xla::ScalarLike(lr, 1.0); in ProximalGradientDescentUpdate() 63 xla::XlaOp zero = xla::ScalarLike(lr, 0.0); in ProximalGradientDescentUpdate() 64 xla::XlaOp prox_var = var - grad * lr; in ProximalGradientDescentUpdate() 65 xla::XlaOp l1_gt_zero = in ProximalGradientDescentUpdate() 67 xla::XlaOp l1_le_zero = prox_var; in ProximalGradientDescentUpdate() 80 xla::XlaOp var; in Compile() [all …]
|
D | stateless_random_ops.cc | 45 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in GetBitGeneratorForDevice() 47 xla::XlaOp philox_state = in GetBitGeneratorForDevice() 49 xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, in GetBitGeneratorForDevice() 55 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in GetBitGeneratorForDevice() 57 xla::XlaOp result = in GetBitGeneratorForDevice() 66 xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { in MaybeConvertF32ToBF16() 69 xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) & in MaybeConvertF32ToBF16() 78 xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, in StatelessRngUniform() 79 xla::XlaOp seeds, const xla::Shape& shape, in StatelessRngUniform() 80 xla::XlaOp minval, xla::XlaOp maxval) { in StatelessRngUniform() [all …]
|
D | segment_reduction_ops.cc | 36 virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; 39 virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0; 91 std::vector<xla::XlaOp> buffer_dims; in Compile() 115 auto combiner = [this](xla::XlaOp a, xla::XlaOp b, in Compile() 133 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue() 136 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; in Combine() 148 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue() 151 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; in Combine() 163 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue() 166 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { in Combine() [all …]
|
/external/tensorflow/tensorflow/core/tpu/kernels/ |
D | topk_ops.cc | 42 xla::XlaOp CreateKthOrderStatisticComputation(xla::XlaBuilder* builder, in CreateKthOrderStatisticComputation() 44 const xla::XlaOp input, in CreateKthOrderStatisticComputation() 45 const xla::XlaOp k) { in CreateKthOrderStatisticComputation() 49 xla::XlaOp input_sm32 = xla::BitcastConvertType(input, xla::S32); in CreateKthOrderStatisticComputation() 50 xla::XlaOp zero_r0 = xla::ConstantR0<int32>(builder, 0); in CreateKthOrderStatisticComputation() 51 xla::XlaOp zero_r1 = xla::Broadcast(zero_r0, {height}); in CreateKthOrderStatisticComputation() 52 xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width}); in CreateKthOrderStatisticComputation() 54 xla::XlaOp max_r0 = xla::ConstantR0<int32>(builder, 0x7FFFFFFF); in CreateKthOrderStatisticComputation() 55 xla::XlaOp max_r1 = xla::Broadcast(max_r0, {height}); in CreateKthOrderStatisticComputation() 58 xla::XlaOp negative_zero_r0 = xla::ConstantR0<int32>(builder, 0x80000000); in CreateKthOrderStatisticComputation() [all …]
|
/external/tensorflow/tensorflow/compiler/tf2xla/lib/ |
D | random.cc | 29 xla::XlaOp TruncatedNormal(xla::XlaOp uniform) { in TruncatedNormal() 42 xla::XlaOp ParameterizedTruncatedNormal(xla::XlaOp uniform, xla::XlaOp mu, in ParameterizedTruncatedNormal() 43 xla::XlaOp sigma, xla::XlaOp a, in ParameterizedTruncatedNormal() 44 xla::XlaOp b) { in ParameterizedTruncatedNormal() 45 xla::XlaOp one = xla::ScalarLike(uniform, 1.0); in ParameterizedTruncatedNormal() 46 xla::XlaOp two = xla::ScalarLike(uniform, 2.0); in ParameterizedTruncatedNormal() 47 xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0)); in ParameterizedTruncatedNormal() 49 auto normal_cdf = [&](xla::XlaOp x) { in ParameterizedTruncatedNormal() 55 xla::XlaOp alpha = (a - mu) / sigma; in ParameterizedTruncatedNormal() 56 xla::XlaOp beta = (b - mu) / sigma; in ParameterizedTruncatedNormal() [all …]
|