• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""A module to support operations on ipynb files"""
16
17import collections
18import copy
19import json
20import re
21import shutil
22import tempfile
23
# One line of notebook code: `cell_number` is the index (counting Python
# cells only) of the cell it was read from, `code` is the encoded line text.
CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"])
25
def is_python(cell):
  """Checks if the cell consists of Python code.

  Args:
    cell: A notebook cell dict with "cell_type" and "source" entries. The
      source may be a list of line strings or a single multi-line string.

  Returns:
    True for a non-empty code cell that does not start with a `%%` cell
    magic (eg: %%bash), False otherwise.
  """
  # code cells only
  if cell["cell_type"] != "code":
    return False

  # non-empty cells
  source = cell["source"]
  if not source:
    return False

  # Indexing a plain string yields a single character, which can never start
  # with "%%"; in that case test the string itself.
  first_line = source if isinstance(source, str) else source[0]
  return not first_line.startswith("%%")
31
32
def process_file(in_filename, out_filename, upgrader):
  """The function where we inject the support for ipynb upgrade.

  Args:
    in_filename: Path of the notebook to read.
    out_filename: Path the upgraded notebook is moved to.
    upgrader: Object providing `update_string_pasta(text, filename)` and
      `_format_log(log, in_filename, out_filename)`.

  Returns:
    A tuple `(files_processed, report_text, errors)` built from the values
    reported by `upgrader`.

  Raises:
    SyntaxError: If the upgrader reports that the code was not processed.
  """
  print("Extracting code lines from original notebook")
  raw_code, notebook = _get_code(in_filename)
  raw_lines = [cl.code for cl in raw_code]

  # The function follows the original flow from `upgrader.process_file`
  processed_file, new_file_content, log, process_errors = (
      upgrader.update_string_pasta("\n".join(raw_lines), in_filename))

  if not processed_file:
    raise SyntaxError(
        "Was not able to process the file: \n%s\n" % "".join(log))

  new_notebook = _update_notebook(notebook, raw_code,
                                  new_file_content.split("\n"))
  report_text = upgrader._format_log(log, in_filename, out_filename)
  serialized = json.dumps(new_notebook)

  # The temporary file is created with delete=False, so it is only opened
  # once every step that can fail on bad input has finished; a failed upgrade
  # therefore leaves no stray file behind.
  with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
    temp_file.write(serialized)

  shutil.move(temp_file.name, out_filename)

  return processed_file, report_text, process_errors
60
61
def skip_magic(code_line, magic_list):
  """Checks if a line starts with a jupyter "magic" prefix.

  Args:
      code_line: A single line taken from a notebook code cell
      magic_list: Prefixes that mark a line as jupyter "magic"

  Returns:
    True when `code_line` starts with any entry of `magic_list`, else False

  >>> skip_magic('!ls -laF', ['%', '!', '?'])
  True
  """
  return any(code_line.startswith(prefix) for prefix in magic_list)
81
82
def check_line_split(code_line):
  r"""Checks if a line was split with `\`.

  Args:
      code_line: A line of code, including its trailing newline

  Returns:
    True if the line ends with `\` (optionally followed by whitespace) and a
    newline, False otherwise

  >>> check_line_split("!gcloud ml-engine models create ${MODEL} \\\n")
  True
  >>> check_line_split("!ls -laF\n")
  False
  """

  # Return a real boolean, as documented, instead of leaking the match object.
  return re.search(r"\\\s*\n$", code_line) is not None
97
98
def _get_code(input_file):
  """Loads the ipynb file and returns a list of CodeLines.

  Args:
    input_file: Path of the notebook (JSON) file to read.

  Returns:
    A tuple `(raw_code, notebook)`: the list of `CodeLine`s extracted from
    the Python cells, and the parsed notebook.
  """

  raw_code = []

  # Notebooks are JSON documents; read them as UTF-8 instead of relying on
  # the platform default encoding.
  with open(input_file, encoding="utf-8") as in_file:
    notebook = json.load(in_file)

  cell_index = 0
  for cell in notebook["cells"]:
    if not is_python(cell):
      continue

    cell_lines = cell["source"]
    if isinstance(cell_lines, str):
      # The source may be stored as one multi-line string; split it on "\n"
      # (keeping the newlines) so it matches the list-of-lines form.
      parts = cell_lines.split("\n")
      cell_lines = [part + "\n" for part in parts[:-1]]
      if parts[-1]:
        cell_lines.append(parts[-1])

    is_line_split = False
    last_idx = len(cell_lines) - 1
    for line_idx, code_line in enumerate(cell_lines):

      # Sometimes, jupyter has more than python code
      # Idea is to comment these lines, for upgrade time
      if is_line_split or skip_magic(code_line, ["%", "!", "?"]):
        # Found a special character, need to "encode"
        code_line = "###!!!" + code_line

        # if this line ends with `\` -> the next line belongs to it and
        # has to be encoded as well
        is_line_split = check_line_split(code_line)

      # Sometimes, people leave \n at the end of cell
      # in order to migrate only related things, and make the diff
      # the smallest -> here is another hack
      if line_idx == last_idx and code_line.endswith("\n"):
        code_line = code_line.replace("\n", "###===")

      # sometimes a line would start with `\n` and content after
      # that's the hack for this
      raw_code.append(
          CodeLine(cell_index,
                   code_line.rstrip().replace("\n", "###===")))

    cell_index += 1

  return raw_code, notebook
142
143
def _update_notebook(original_notebook, original_raw_lines, updated_code_lines):
  """Updates notebook, once migration is done.

  Args:
    original_notebook: The parsed notebook; it is not modified.
    original_raw_lines: The `CodeLine`s extracted from `original_notebook`.
    updated_code_lines: The converted code, one string per entry of
      `original_raw_lines`.

  Returns:
    A deep copy of `original_notebook` whose Python cells carry the
    converted code.

  Raises:
    AssertionError: If the two line lists differ in length.
  """

  # validate that the number of lines is the same; raised explicitly so the
  # check is not stripped when Python runs with -O
  if len(original_raw_lines) != len(updated_code_lines):
    raise AssertionError(
        "The lengths of input and converted files are not the same: "
        "{} vs {}".format(len(original_raw_lines), len(updated_code_lines)))

  new_notebook = copy.deepcopy(original_notebook)

  # Group the converted lines by cell in one pass instead of rescanning
  # every line for every cell.
  lines_by_cell = collections.defaultdict(list)
  for code_line, updated_line in zip(original_raw_lines, updated_code_lines):
    lines_by_cell[code_line.cell_number].append(updated_line)

  code_cell_idx = 0
  for cell in new_notebook["cells"]:
    if not is_python(cell):
      continue

    new_code = lines_by_cell.get(code_cell_idx, [])

    # Drop the magic marker and turn the newline marker back into "\n"
    cell["source"] = "\n".join(new_code).replace("###!!!", "").replace(
        "###===", "\n")
    code_cell_idx += 1

  return new_notebook
171