• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef HLO_OPS_BASE_STRUCTS
17#define HLO_OPS_BASE_STRUCTS
18
//===----------------------------------------------------------------------===//
// Struct attribute definitions for dimension numbers and channel handles.
//===----------------------------------------------------------------------===//
22
// Dimension numbers for a dot op. Batching and contracting dimensions are
// listed separately for the left-hand and right-hand operands; every field
// is an I64ElementsAttr.
def DotDimensionNumbers
    : StructAttr<"DotDimensionNumbers", HLO_Dialect,
                 [StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
                  StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
                  StructFieldAttr<"lhs_contracting_dimensions",
                                  I64ElementsAttr>,
                  StructFieldAttr<"rhs_contracting_dimensions",
                                  I64ElementsAttr>]> {
  let summary = "Structure of dimension information for dot product";
}
31
// Dimension numbers for a scatter op: three I64ElementsAttr dimension lists
// followed by a single scalar I64Attr, `index_vector_dim`.
def ScatterDimensionNumbers
    : StructAttr<"ScatterDimensionNumbers", HLO_Dialect,
                 [StructFieldAttr<"update_window_dims", I64ElementsAttr>,
                  StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
                  StructFieldAttr<"scatter_dims_to_operand_dims",
                                  I64ElementsAttr>,
                  StructFieldAttr<"index_vector_dim", I64Attr>]> {
  let summary = "Structure of dimension information for scatter";
}
40
// Dimension numbers for a convolution op, grouped by prefix:
//   input_*  - batch and feature dimensions (I64Attr) plus spatial dims;
//   kernel_* - input/output feature dimensions (I64Attr) plus spatial dims
//              (there is no kernel batch field);
//   output_* - batch and feature dimensions (I64Attr) plus spatial dims.
// Each *_spatial_dimensions field is an I64ElementsAttr.
def ConvDimensionNumbers
    : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
        StructFieldAttr<"input_batch_dimension", I64Attr>,
        StructFieldAttr<"input_feature_dimension", I64Attr>,
        StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,

        StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
        StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
        StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,

        StructFieldAttr<"output_batch_dimension", I64Attr>,
        StructFieldAttr<"output_feature_dimension", I64Attr>,
        StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>
      ]> {
  let summary = "Structure of dimension information for conv op";
}
54
// Dimension numbers for a gather op: three I64ElementsAttr dimension lists
// followed by a single scalar I64Attr, `index_vector_dim`.
def GatherDimensionNumbers
    : StructAttr<"GatherDimensionNumbers", HLO_Dialect, [
        StructFieldAttr<"offset_dims", I64ElementsAttr>,
        StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
        StructFieldAttr<"start_index_map", I64ElementsAttr>,
        StructFieldAttr<"index_vector_dim", I64Attr>
      ]> {
  let summary = "Structure of dimension information for gather";
}
62
63
// Unique identifier for a Send/Recv instruction pair; it may also be
// attached to collective instructions (AllReduce, CollectivePermute,
// AllToAll). A non-positive channel_id handle means "no channel id".
// Holds two I64Attr fields: `handle` and `type`.
def ChannelHandle
    : StructAttr<"ChannelHandle", HLO_Dialect,
                 [StructFieldAttr<"handle", I64Attr>,
                  StructFieldAttr<"type", I64Attr>]> {
  let summary = "two 64-bit integers 'handle' and 'type'";
}
72
73#endif // HLO_OPS_BASE_STRUCTS
74