# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""A module to support operations on ipynb files"""

import collections
import copy
import json
import os
import re
import shutil
import tempfile

# One line of code extracted from a notebook, tagged with the index of the
# Python cell it came from (counting only cells accepted by `is_python`).
CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"])

# Prefix that turns a non-Python line (magic, shell escape, help query) into a
# comment while the upgrader runs; it is stripped again when writing back.
_MAGIC_MARKER = "###!!!"
# Stand-in for newline characters that must survive the round trip through
# the upgrader (e.g. the trailing newline of a cell's last line).
_NEWLINE_MARKER = "###==="
# Line prefixes that mark jupyter "magic" rather than Python code.
_MAGIC_PREFIXES = ("%", "!", "?")
# A backslash continuation, optionally followed by whitespace, right before
# the line's final newline.
_LINE_SPLIT_RE = re.compile(r"\\\s*\n$")


def _cell_source_lines(cell):
  """Returns the source of a notebook cell as a list of lines.

  nbformat allows `source` to be either a list of strings or one multi-line
  string (`_update_notebook` below writes the latter). A string is split
  after each newline, keeping the newline, to match the list form.
  """
  source = cell["source"]
  if not isinstance(source, str):
    return source
  pieces = source.split("\n")
  lines = [piece + "\n" for piece in pieces[:-1]]
  if pieces[-1]:
    lines.append(pieces[-1])
  return lines


def is_python(cell):
  """Checks if the cell consists of Python code.

  Args:
    cell: A notebook cell dict.

  Returns:
    True for a non-empty code cell whose first line is not a cell magic
    (e.g. `%%bash`), False otherwise.
  """
  if cell["cell_type"] != "code":  # code cells only
    return False
  lines = _cell_source_lines(cell)
  # Non-empty cells only, and no multiline magic such as `%%bash`.
  return bool(lines) and not lines[0].startswith("%%")


def process_file(in_filename, out_filename, upgrader):
  """Upgrades the Python code inside a Jupyter notebook.

  The Python cells are flattened into one source string, upgraded with
  `upgrader.update_string_pasta`, split back into their cells, and the new
  notebook is written to `out_filename` through a temporary file.

  Args:
    in_filename: Path of the `.ipynb` file to read.
    out_filename: Path the upgraded notebook is written to (may equal
      `in_filename`).
    upgrader: Object providing `update_string_pasta(text, filename)` and
      `_format_log(log, in_filename, out_filename)`.

  Returns:
    A `(files_processed, report_text, errors)` tuple: the processed flag
    returned by `update_string_pasta`, the log formatted by `_format_log`,
    and the errors reported by `update_string_pasta`.

  Raises:
    SyntaxError: If the upgrader could not process the extracted code.
    ValueError: If the upgraded code has a different number of lines than
      the extracted code.
  """
  print("Extracting code lines from original notebook")
  raw_code, notebook = _get_code(in_filename)
  raw_lines = [cl.code for cl in raw_code]

  # Same flow as the upgrader's own `process_file`, applied to the code
  # flattened out of the notebook.
  processed_file, new_file_content, log, process_errors = (
      upgrader.update_string_pasta("\n".join(raw_lines), in_filename))

  if not processed_file:
    raise SyntaxError(
        "Was not able to process the file: \n%s\n" % "".join(log))

  # "".split("\n") is [""], which cannot be matched against an empty
  # `raw_code`; a notebook without Python cells has no lines to map back.
  updated_lines = new_file_content.split("\n") if raw_code else []
  new_notebook = _update_notebook(notebook, raw_code, updated_lines)
  report_text = upgrader._format_log(log, in_filename, out_filename)  # pylint: disable=protected-access

  # Serialize into a temporary file first, so `out_filename` (possibly the
  # input file itself) is not touched until the new notebook is complete.
  temp_file = tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8")
  try:
    with temp_file:
      json.dump(new_notebook, temp_file)
    shutil.move(temp_file.name, out_filename)
  except BaseException:
    # Best effort: don't leak the temporary file, but keep the original error.
    try:
      os.remove(temp_file.name)
    except OSError:
      pass
    raise

  return processed_file, report_text, process_errors


def skip_magic(code_line, magic_list):
  """Checks if the line is jupyter "magic" rather than Python code.

  Args:
    code_line: A line of code from a notebook cell.
    magic_list: Prefixes that mark jupyter "magic" lines.

  Returns:
    True if `code_line` starts with any prefix in `magic_list`.

  >>> skip_magic('!ls -laF', ['%', '!', '?'])
  True
  """
  return any(code_line.startswith(magic) for magic in magic_list)


def check_line_split(code_line):
  r"""Checks if a line was split with `\`.

  Args:
    code_line: A line of code, including its trailing newline.

  Returns:
    True if the line ends with `\` (optionally followed by whitespace) right
    before its final newline.

  >>> check_line_split("!gcloud ml-engine models create ${MODEL} \\\n")
  True
  """
  return bool(_LINE_SPLIT_RE.search(code_line))


def _get_code(input_file):
  """Loads the ipynb file and returns its Python code as CodeLines.

  Non-Python lines (magics, shell escapes, help queries and the backslash
  continuations that follow them) are prefixed with `_MAGIC_MARKER`, which
  turns them into comments while the upgrader runs. Newlines that must
  survive the round trip are encoded as `_NEWLINE_MARKER`.

  Args:
    input_file: Path of the `.ipynb` file.

  Returns:
    A `(raw_code, notebook)` tuple: the `CodeLine`s of every Python cell, in
    order, and the parsed notebook dict.
  """
  # Notebooks are UTF-8 JSON; don't depend on the locale's default encoding.
  with open(input_file, encoding="utf-8") as in_file:
    notebook = json.load(in_file)

  raw_code = []
  python_cells = (cell for cell in notebook["cells"] if is_python(cell))
  for cell_index, cell in enumerate(python_cells):
    cell_lines = _cell_source_lines(cell)
    last_line_idx = len(cell_lines) - 1

    is_line_split = False
    for line_idx, code_line in enumerate(cell_lines):
      # Jupyter cells may hold more than Python code; comment such lines out
      # for the duration of the upgrade. A magic line ending in `\` carries
      # on into the next line, which must be commented out as well.
      if is_line_split or skip_magic(code_line, _MAGIC_PREFIXES):
        code_line = _MAGIC_MARKER + code_line
        is_line_split = check_line_split(code_line)

      # Keep the trailing newline of the cell's last line, so that the diff
      # contains only the actual migration changes.
      if line_idx == last_line_idx and code_line.endswith("\n"):
        code_line = code_line.replace("\n", _NEWLINE_MARKER)

      # Encode any other embedded newline (e.g. a line that starts with "\n"
      # followed by content) so it cannot shift the line numbering.
      raw_code.append(
          CodeLine(cell_index,
                   code_line.rstrip().replace("\n", _NEWLINE_MARKER)))

  return raw_code, notebook


def _update_notebook(original_notebook, original_raw_lines, updated_code_lines):
  """Returns a copy of the notebook with the upgraded code written back.

  Args:
    original_notebook: The notebook dict returned by `_get_code`.
    original_raw_lines: The `CodeLine`s returned by `_get_code`.
    updated_code_lines: The upgraded code, one entry per `CodeLine`.

  Returns:
    A deep copy of `original_notebook` whose Python cells have `source`
    replaced by the upgraded code as a single string, with the markers added
    by `_get_code` decoded again.

  Raises:
    ValueError: If the number of upgraded lines differs from the number of
      extracted lines, so the lines cannot be mapped back to their cells.
  """
  # Validate that the number of lines is the same (explicitly, so the check
  # also runs under `python -O`).
  if len(original_raw_lines) != len(updated_code_lines):
    raise ValueError(
        "The lengths of input and converted files are not the same: "
        "{} vs {}".format(len(original_raw_lines), len(updated_code_lines)))

  # Group the upgraded lines by the Python cell they came from, in one pass.
  lines_by_cell = collections.defaultdict(list)
  for raw_line, updated_line in zip(original_raw_lines, updated_code_lines):
    lines_by_cell[raw_line.cell_number].append(updated_line)

  new_notebook = copy.deepcopy(original_notebook)
  python_cells = (cell for cell in new_notebook["cells"] if is_python(cell))
  for code_cell_idx, cell in enumerate(python_cells):
    new_code = "\n".join(lines_by_cell[code_cell_idx])
    cell["source"] = new_code.replace(_MAGIC_MARKER, "").replace(
        _NEWLINE_MARKER, "\n")

  return new_notebook