• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
17 
18 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
19 #include "tensorflow/core/platform/errors.h"
20 
21 namespace tensorflow {
22 
PopulateTfVersions(mlir::ModuleOp module,const VersionDef & versions)23 void PopulateTfVersions(mlir::ModuleOp module, const VersionDef& versions) {
24   mlir::Builder b(module.getContext());
25   auto producer =
26       b.getNamedAttr("producer", b.getI32IntegerAttr(versions.producer()));
27   auto min_consumer = b.getNamedAttr(
28       "min_consumer", b.getI32IntegerAttr(versions.min_consumer()));
29   auto bad_consumers = b.getNamedAttr(
30       "bad_consumers",
31       b.getI32ArrayAttr(llvm::ArrayRef<int32_t>(
32           versions.bad_consumers().begin(), versions.bad_consumers().end())));
33   module->setAttr("tf.versions",
34                   b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
35                       {producer, min_consumer, bad_consumers})));
36 }
37 
ExtractTfVersions(mlir::ModuleOp module,VersionDef * versions)38 mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module,
39                                       VersionDef* versions) {
40   versions->Clear();
41   auto version_attr =
42       module->getAttrOfType<mlir::DictionaryAttr>("tf.versions");
43   if (!version_attr) return mlir::failure();
44 
45   auto producer =
46       version_attr.get("producer").dyn_cast_or_null<mlir::IntegerAttr>();
47   if (!producer) return mlir::failure();
48   versions->set_producer(producer.getInt());
49 
50   auto min_consumer =
51       version_attr.get("min_consumer").dyn_cast_or_null<mlir::IntegerAttr>();
52   if (min_consumer) versions->set_min_consumer(min_consumer.getInt());
53 
54   auto bad_consumers =
55       version_attr.get("bad_consumers").dyn_cast_or_null<mlir::ArrayAttr>();
56   if (!bad_consumers) return mlir::success();
57 
58   for (auto bad_consumer : bad_consumers) {
59     auto bad_consumer_int_attr =
60         bad_consumer.dyn_cast_or_null<mlir::IntegerAttr>();
61     if (!bad_consumer_int_attr) return mlir::failure();
62 
63     versions->mutable_bad_consumers()->Add(bad_consumer_int_attr.getInt());
64   }
65   return mlir::success();
66 }
67 
GetTfGraphProducerVersion(mlir::ModuleOp module)68 ::stream_executor::port::StatusOr<int64_t> GetTfGraphProducerVersion(
69     mlir::ModuleOp module) {
70   auto versions = module->getAttrOfType<::mlir::DictionaryAttr>("tf.versions");
71   if (!versions) {
72     return errors::Internal(
73         "Missing 'tf.versions' attribute on the module, abort.\n");
74   }
75   auto producer = versions.get("producer").dyn_cast<mlir::IntegerAttr>();
76   if (!producer) {
77     return errors::Internal(
78         "Missing 'producer' attribute on the module, abort.\n");
79   }
80   return producer.getInt();
81 }
82 
83 }  // namespace tensorflow
84