• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""A module to support operation on ipynb files"""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import json
24import shutil
25import tempfile
26
# One extracted line of notebook code. `cell_number` is the 0-based index of
# the code cell the line came from (only code cells are counted), and `code`
# is the line text as prepared for the upgrader by `_get_code`.
CodeLine = collections.namedtuple("CodeLine", ("cell_number", "code"))
28
29
def process_file(in_filename, out_filename, upgrader):
  """Upgrade the code cells of a Jupyter notebook with `upgrader`.

  Follows the flow of `upgrader.process_file` for plain Python files. The
  code cells are flattened into one source string (see `_get_code`) and
  upgraded in a single pass with `upgrader.update_string_pasta`. The result
  is put back into the cells (see `_update_notebook`) and written to
  `out_filename`.

  Args:
    in_filename: path of the notebook to read.
    out_filename: path the upgraded notebook is written to.
    upgrader: object providing two methods.
      `update_string_pasta(text, in_filename)` returns a
      `(processed, new_content, log, errors)` tuple.
      `_format_log(log, in_filename, out_filename)` builds the report.

  Returns:
    A `(files_processed, report_text, errors)` tuple. `files_processed` and
    `errors` are the `processed` flag and `errors` returned by
    `update_string_pasta`; `report_text` is what `_format_log` returns.

  Raises:
    SyntaxError: if `update_string_pasta` reports that the extracted code
      could not be processed.
  """
  import os  # pylint: disable=g-import-not-at-top

  print("Extracting code lines from original notebook")
  raw_code, notebook = _get_code(in_filename)
  raw_lines = [cl.code for cl in raw_code]

  processed_file, new_file_content, log, process_errors = (
      upgrader.update_string_pasta("\n".join(raw_lines), in_filename))

  # Fail before any temporary file exists, so nothing is left behind.
  if not processed_file:
    raise SyntaxError(
        "Was not able to process the file: \n%s\n" % "".join(log))

  new_notebook = _update_notebook(notebook, raw_code,
                                  new_file_content.split("\n"))
  report_text = upgrader._format_log(log, in_filename, out_filename)

  # Serialize into a temporary file first and move it into place afterwards,
  # so a failure while writing the JSON does not leave a partially written
  # notebook at `out_filename`. If writing or moving fails, the temporary
  # file is removed instead of being orphaned.
  temp_file = tempfile.NamedTemporaryFile("w", delete=False)
  try:
    with temp_file:
      json.dump(new_notebook, temp_file)
    shutil.move(temp_file.name, out_filename)
  except BaseException:
    try:
      os.remove(temp_file.name)
    except OSError:
      pass  # Best effort: surface the original error, not a cleanup failure.
    raise

  return processed_file, report_text, process_errors
57
58
def _get_code(input_file):
  """Load an ipynb file and flatten its code cells into `CodeLine`s.

  Each line of each code cell becomes one `CodeLine(cell_number, code)`.
  `cell_number` counts code cells only; other cell types are skipped and do
  not advance it. Each line is prepared for the Python upgrader as follows:

  * Lines starting with `%`, `!` or `?` get the `###!!!` prefix, which turns
    them into comments.
  * A trailing newline on the last line of a cell, and any newline still
    embedded in a line, is encoded as `###===`.
  * Trailing whitespace is stripped.

  `_update_notebook` decodes both markers again.

  Args:
    input_file: path of the notebook to read, decoded as UTF-8 JSON.

  Returns:
    A `(raw_code, notebook)` tuple: the list of `CodeLine`s in notebook
    order, and the parsed notebook.
  """
  import io  # pylint: disable=g-import-not-at-top

  # Decode explicitly as UTF-8 rather than with the platform locale encoding,
  # which can fail on notebooks containing non-ASCII text.
  with io.open(input_file, encoding="utf-8") as in_file:
    notebook = json.load(in_file)

  raw_code = []
  code_cells = (c for c in notebook["cells"] if c["cell_type"] == "code")
  for cell_index, cell in enumerate(code_cells):
    cell_lines = cell["source"]
    # The notebook format allows `source` to be a single string as well as a
    # list of lines. Iterating a string would yield single characters, so
    # split it the way nbformat does, keeping line endings.
    if not isinstance(cell_lines, list):
      cell_lines = cell_lines.splitlines(True)

    last_idx = len(cell_lines) - 1
    for line_idx, code_line in enumerate(cell_lines):
      # Jupyter cells may contain non-Python lines; comment them out for the
      # upgrade. `_update_notebook` removes the marker again.
      if code_line.startswith(("%", "!", "?")):
        code_line = "###!!!" + code_line

      # Keep a trailing newline at the end of a cell, so the diff stays as
      # small as possible. Encode it here, otherwise `rstrip` below drops it.
      if line_idx == last_idx and code_line.endswith("\n"):
        code_line = code_line.replace("\n", "###===")

      # Encode any newline left inside the line (e.g. a line starting with
      # "\n") after trailing whitespace has been stripped.
      raw_code.append(
          CodeLine(cell_index, code_line.rstrip().replace("\n", "###===")))

  return raw_code, notebook
96
97
def _update_notebook(original_notebook, original_raw_lines, updated_code_lines):
  """Write upgraded code lines back into a copy of the notebook.

  This is the inverse of `_get_code`. The upgraded lines are regrouped by the
  code cell they were extracted from and joined with newlines. The `###!!!`
  and `###===` markers added during extraction are then decoded.

  Args:
    original_notebook: the parsed notebook; it is deep-copied, not modified.
    original_raw_lines: the `CodeLine`s extracted from the notebook. Only
      their `cell_number` is used, to map each updated line back to its cell.
    updated_code_lines: the upgraded source, one entry per element of
      `original_raw_lines`, in the same order.

  Returns:
    A deep copy of `original_notebook`. Each code cell's `source` is
    replaced by its upgraded code as a single string.

  Raises:
    AssertionError: if `original_raw_lines` and `updated_code_lines` differ
      in length, so the lines can no longer be mapped back to their cells.
  """
  # Raise explicitly rather than using `assert`, so the check is not stripped
  # under `python -O`. AssertionError is kept for existing callers.
  if len(original_raw_lines) != len(updated_code_lines):
    raise AssertionError(
        "The lengths of input and converted files are not the same: "
        "{} vs {}".format(len(original_raw_lines), len(updated_code_lines)))

  new_notebook = copy.deepcopy(original_notebook)

  # Group the updated lines by their code cell in one pass, in their
  # original order, instead of rescanning every line once per cell.
  lines_by_cell = collections.defaultdict(list)
  for raw_line, updated_line in zip(original_raw_lines, updated_code_lines):
    lines_by_cell[raw_line.cell_number].append(updated_line)

  code_cells = (c for c in new_notebook["cells"] if c["cell_type"] == "code")
  for code_cell_idx, cell in enumerate(code_cells):
    new_code = "\n".join(lines_by_cell.get(code_cell_idx, []))
    cell["source"] = new_code.replace("###!!!", "").replace("###===", "\n")

  return new_notebook
125