# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Serializes an Environment into a batch file."""

# Disable super() warnings since this file must be Python 2 compatible.
# pylint: disable=super-with-arguments

# goto label written to the end of Windows batch files for exiting a script.
_SCRIPT_END_LABEL = '_pw_end'


class BatchVisitor(object):  # pylint: disable=useless-object-inheritance
    """Serializes an Environment into a batch file.

    Each visit_* method writes the batch statement(s) for one environment
    action. Actions with no batch equivalent (remove, function, hash) are
    skipped and produce no output.
    """

    def __init__(self, *args, **kwargs):
        """Create a visitor.

        Args:
            pathsep (str, keyword): Separator used by prepend/append actions.
                Defaults to ':'; callers writing Windows batch files
                presumably pass ';'.
            *args, **kwargs: Remaining arguments, forwarded to super().
        """
        pathsep = kwargs.pop('pathsep', ':')
        super(BatchVisitor, self).__init__(*args, **kwargs)
        # (variable name, literal value) pairs; only populated in serialize().
        self._replacements = ()
        # Output stream; only set while serialize() is running.
        self._outs = None
        self._pathsep = pathsep

    def serialize(self, env, outs):
        """Write a batch file based on the given environment.

        Args:
            env (environment.Environment): Environment variables to use.
            outs (file): Batch file to write.
        """
        try:
            # A replacement given as None takes its value from env.get(key).
            self._replacements = tuple(
                (key, env.get(key) if value is None else value)
                for key, value in env.replacements
            )
            self._outs = outs
            self._outs.write('@echo off\n')

            env.accept(self)

            # Target of every "goto" emitted for failing commands.
            outs.write(':{}\n'.format(_SCRIPT_END_LABEL))

        finally:
            # Reset per-call state even when serialization raises.
            self._replacements = ()
            self._outs = None

    def _apply_replacements(self, action):
        """Return action.value with known literal values turned into %VAR%.

        Replacements for the variable being written (action.name) are not
        applied, presumably so a variable is never defined in terms of
        itself. Empty or missing (None) replacement values are skipped:
        str.replace('', ...) would insert the reference between every
        character, and None cannot be searched for at all.
        """
        value = action.value
        for var, replacement in self._replacements:
            if var == action.name or not replacement:
                continue
            value = value.replace(replacement, '%{}%'.format(var))
        return value

    def visit_set(self, set):  # pylint: disable=redefined-builtin
        """Write `set NAME=VALUE` with replacements applied to VALUE."""
        value = self._apply_replacements(set)
        self._outs.write(
            'set {name}={value}\n'.format(name=set.name, value=value)
        )

    def visit_clear(self, clear):
        """Write `set NAME=`, which unsets NAME in batch."""
        self._outs.write('set {name}=\n'.format(name=clear.name))

    def visit_remove(self, remove):
        pass  # Not supported on Windows.

    def _join(self, *args):
        """Join the arguments (or a single list/tuple of them) with pathsep."""
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            args = args[0]
        return self._pathsep.join(args)

    def visit_prepend(self, prepend):
        """Write `set NAME=VALUE<sep>%NAME%`."""
        value = self._apply_replacements(prepend)
        value = self._join(value, '%{}%'.format(prepend.name))
        self._outs.write(
            'set {name}={value}\n'.format(name=prepend.name, value=value)
        )

    def visit_append(self, append):
        """Write `set NAME=%NAME%<sep>VALUE`."""
        value = self._apply_replacements(append)
        value = self._join('%{}%'.format(append.name), value)
        self._outs.write(
            'set {name}={value}\n'.format(name=append.name, value=value)
        )

    def visit_echo(self, echo):
        """Print echo.value, with or without a trailing newline."""
        if echo.newline:
            if not echo.value:
                # Bare "echo" reports the echo state; "echo." prints a blank.
                self._outs.write('echo.\n')
            else:
                self._outs.write('echo {}\n'.format(echo.value))
        else:
            # "set /p" prints its prompt without a newline; stdin from nul
            # keeps it from waiting for input.
            self._outs.write('<nul set /p="{}"\n'.format(echo.value))

    def visit_comment(self, comment):
        """Write each line of comment.value as a `::` comment line."""
        for line in comment.value.splitlines():
            self._outs.write(':: {}\n'.format(line))

    def _write_command(self, command, in_block):
        """Write command.command and, if requested, an exit-on-error check.

        Args:
            command: Action with `command` (sequence of str) and
                `exit_on_error` attributes.
            in_block (bool): True when the output is placed inside a
                parenthesized "( ... )" block.
        """
        # TODO(mohrr) use shlex.quote here?
        self._outs.write('{}\n'.format(' '.join(command.command)))
        if not command.exit_on_error:
            return

        # Assume failing command produced relevant output.
        if not in_block:
            self._outs.write(
                'if %ERRORLEVEL% neq 0 goto {}\n'.format(_SCRIPT_END_LABEL)
            )
            return

        # cmd.exe expands %VAR% for an entire parenthesized block when the
        # block is parsed, i.e. before the command above runs, so
        # %ERRORLEVEL% would be stale here. "if errorlevel N" is evaluated at
        # runtime; together these two lines fire on any nonzero (positive or
        # negative) exit code, matching "neq 0".
        self._outs.write(
            'if errorlevel 1 goto {}\n'.format(_SCRIPT_END_LABEL)
        )
        self._outs.write(
            'if not errorlevel 0 goto {}\n'.format(_SCRIPT_END_LABEL)
        )

    def visit_command(self, command):
        """Write a top-level command, jumping to the end if it fails."""
        self._write_command(command, in_block=False)

    def visit_doctor(self, doctor):
        """Run the doctor command unless PW_ACTIVATE_SKIP_CHECKS is set."""
        self._outs.write('if "%PW_ACTIVATE_SKIP_CHECKS%"=="" (\n')
        self._write_command(doctor, in_block=True)
        self._outs.write(') else (\n')
        self._outs.write(
            'echo Skipping environment check because '
            'PW_ACTIVATE_SKIP_CHECKS is set\n'
        )
        self._outs.write(')\n')

    def visit_blank_line(self, blank_line):
        """Write an empty line."""
        del blank_line
        self._outs.write('\n')

    def visit_function(self, function):
        pass  # Not supported on Windows.

    def visit_hash(self, hash):  # pylint: disable=redefined-builtin
        pass  # Not relevant on Windows.