• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
18 
19 #include <vector>
20 
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/core/platform/hash.h"
23 #include "tensorflow/core/util/autotune_maps/autotune_maps_utils.h"
24 #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 using ::tensorflow::protobuf::util::MessageDifferencer;
30 
ComputeHash(int device_id,const ConvParametersProto & proto)31 uint64 ComputeHash(int device_id, const ConvParametersProto& proto) {
32   return Hash64Combine(device_id, autotune_maps_utils::HashProto(proto));
33 }
34 }  // namespace
35 
ConvParameters(int64_t batch,int64_t in_depths,const absl::Span<const int64_t> in,int data_format,int64_t out_depths,const absl::Span<const int64_t> filter,const absl::Span<const int64_t> dilation,const absl::Span<const int64_t> stride,const absl::Span<const int64_t> padding,DataType dtype,int device_id,int group_count,absl::optional<ConvParameters::FusionInfo> fusion_info,int version)36 ConvParameters::ConvParameters(
37     int64_t batch, int64_t in_depths, const absl::Span<const int64_t> in,
38     int data_format, int64_t out_depths, const absl::Span<const int64_t> filter,
39     const absl::Span<const int64_t> dilation,
40     const absl::Span<const int64_t> stride,
41     const absl::Span<const int64_t> padding, DataType dtype, int device_id,
42     int group_count, absl::optional<ConvParameters::FusionInfo> fusion_info,
43     int version)
44     : device_id_(device_id) {
45   proto_.set_batch(batch);
46   proto_.set_in_depths(in_depths);
47   *proto_.mutable_in() = {in.begin(), in.end()};
48   proto_.set_data_format(static_cast<int>(data_format));
49   proto_.set_out_depths(out_depths);
50   *proto_.mutable_filter() = {filter.begin(), filter.end()};
51   *proto_.mutable_dilation() = {dilation.begin(), dilation.end()};
52   *proto_.mutable_stride() = {stride.begin(), stride.end()};
53   *proto_.mutable_padding() = {padding.begin(), padding.end()};
54   proto_.set_dtype(dtype);
55   proto_.set_group_count(group_count);
56   if (fusion_info.has_value()) {
57     ConvParametersProto::Fusion fusion_proto;
58     fusion_proto.set_conv_scale(fusion_info.value().conv_scale);
59     fusion_proto.set_side_input_scale(fusion_info.value().side_input_scale);
60     fusion_proto.set_activation_mode(fusion_info.value().activation_mode);
61     fusion_proto.set_is_contrib(fusion_info.value().is_contrib);
62     *proto_.mutable_fusion() = fusion_proto;
63   }
64   proto_.set_device_identifier(
65       autotune_maps_utils::DeviceIdToIdentifier(device_id));
66   proto_.set_version(version);
67   hash_code_ = ComputeHash(device_id_, proto_);
68 }
69 
ConvParameters(int device_id,const ConvParametersProto & proto)70 ConvParameters::ConvParameters(int device_id, const ConvParametersProto& proto)
71     : device_id_(device_id),
72       proto_(proto),
73       hash_code_(ComputeHash(device_id, proto_)) {}
74 
operator ==(const ConvParameters & other) const75 bool ConvParameters::operator==(const ConvParameters& other) const {
76   return device_id_ == other.device_id_ &&
77          MessageDifferencer::Equals(this->proto_, other.proto_);
78 }
79 
ToString() const80 string ConvParameters::ToString() const { return proto_.DebugString(); }
81 
82 }  // namespace tensorflow
83 
84 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
85