• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Stores the environment changes necessary for Pigweed."""
15
16import contextlib
17import os
18import re
19
20# The order here is important. On Python 2 we want StringIO.StringIO and not
21# io.StringIO. On Python 3 there is no StringIO module so we want io.StringIO.
22# Not using six because six is not a standard package we can expect to have
23# installed in the system Python.
24try:
25    from StringIO import StringIO  # type: ignore
26except ImportError:
27    from io import StringIO
28
29from . import apply_visitor
30from . import batch_visitor
31from . import gni_visitor
32from . import json_visitor
33from . import shell_visitor
34
35# Disable super() warnings since this file must be Python 2 compatible.
36# pylint: disable=super-with-arguments
37
38
class BadNameType(TypeError):
    """Raised when a variable name is not a string."""
41
42
class BadValueType(TypeError):
    """Raised when a variable value is not a string."""
45
46
class EmptyValue(ValueError):
    """Raised when a variable value is '' and empty values are not allowed."""
49
50
class NewlineInValue(TypeError):
    """Raised when a variable value contains a newline character."""
53
54
class BadVariableName(ValueError):
    """Raised when a variable name is not a valid identifier-style name."""
57
58
class UnexpectedAction(ValueError):
    """Raised for an action of an unexpected kind.

    NOTE(review): not raised anywhere in this file; presumably used by the
    visitor modules.
    """
61
62
class AcceptNotOverridden(TypeError):
    """Raised by the base accept() when a subclass does not override it."""
65
66
67class _Action(object):  # pylint: disable=useless-object-inheritance
68    def unapply(self, env, orig_env):
69        pass
70
71    def accept(self, visitor):
72        del visitor
73        raise AcceptNotOverridden('accept() not overridden for {}'.format(
74            self.__class__.__name__))
75
76    def write_deactivate(self,
77                         outs,
78                         windows=(os.name == 'nt'),
79                         replacements=()):
80        pass
81
82
class _VariableAction(_Action):
    """Base class for actions that operate on a single named variable.

    Args:
      name(str): variable name; must consist of ASCII letters, digits and
        underscores and must not start with a digit.
      value(str): value for the action; must not contain a newline and may
        only be '' when allow_empty_values is true.
      allow_empty_values(bool): permit value == ''.

    Raises:
      BadNameType: name is not a string.
      BadValueType: value is not a string.
      EmptyValue: value is '' and allow_empty_values is false.
      NewlineInValue: value contains a newline.
      BadVariableName: name has an invalid form.
    """
    # pylint: disable=keyword-arg-before-vararg
    def __init__(self, name, value, allow_empty_values=False, *args, **kwargs):
        super(_VariableAction, self).__init__(*args, **kwargs)
        self.name = name
        self.value = value
        self.allow_empty_values = allow_empty_values

        # Subclasses (e.g. Prepend/Append) extend _check() with extra rules.
        self._check()

    def _check(self):
        """Validate self.name and self.value, raising on the first problem."""
        try:
            # In python2, unicode is a distinct type.
            valid_types = (str, unicode)  # pylint: disable=undefined-variable
        except NameError:
            valid_types = (str, )

        if not isinstance(self.name, valid_types):
            raise BadNameType('variable name {!r} not of type str'.format(
                self.name))
        if not isinstance(self.value, valid_types):
            raise BadValueType('{!r} value {!r} not of type str'.format(
                self.name, self.value))

        # Empty strings as environment variable values have different behavior
        # on different operating systems. Just don't allow them.
        if not self.allow_empty_values and self.value == '':
            raise EmptyValue('{!r} value {!r} is the empty string'.format(
                self.name, self.value))

        # Many tools have issues with newlines in environment variable values.
        # Just don't allow them.
        if '\n' in self.value:
            raise NewlineInValue('{!r} value {!r} contains a newline'.format(
                self.name, self.value))

        # Explicit ASCII ranges instead of re.IGNORECASE: on Python 3,
        # [A-Z] with IGNORECASE also matches a few non-ASCII letters
        # (e.g. the Kelvin sign). \Z instead of $: $ also matches just
        # before a trailing newline, which would accept names like 'FOO\n'.
        if not re.match(r'[A-Za-z_][A-Za-z0-9_]*\Z', self.name):
            raise BadVariableName('bad variable name {!r}'.format(self.name))

    def unapply(self, env, orig_env):
        """Restore self.name in env to its value in orig_env (or delete it)."""
        if self.name in orig_env:
            env[self.name] = orig_env[self.name]
        else:
            env.pop(self.name, None)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.name,
                                   self.value)
131
132
class Set(_VariableAction):
    """Set a variable.

    Accepts an extra keyword argument 'deactivate' (default True), stored as
    self.deactivate; all other arguments are forwarded to _VariableAction.
    """
    def accept(self, visitor):
        """Dispatch to visitor.visit_set()."""
        visitor.visit_set(self)

    def __init__(self, *args, **kwargs):
        # Pop before forwarding so _VariableAction never sees 'deactivate'.
        should_deactivate = kwargs.pop('deactivate', True)
        super(Set, self).__init__(*args, **kwargs)
        self.deactivate = should_deactivate
142
143
class Clear(_VariableAction):
    """Remove a variable from the environment.

    The value is always forced to '' (with empty values allowed), so only
    the variable name needs to be supplied.
    """
    def accept(self, visitor):
        """Dispatch to visitor.visit_clear()."""
        visitor.visit_clear(self)

    def __init__(self, *args, **kwargs):
        kwargs.update(value='', allow_empty_values=True)
        super(Clear, self).__init__(*args, **kwargs)
153
154
class Remove(_VariableAction):
    """Remove a value from a PATH-like variable.

    Takes exactly the constructor arguments of _VariableAction.
    """
    def accept(self, visitor):
        """Dispatch to visitor.visit_remove()."""
        visitor.visit_remove(self)
159
160
class BadVariableValue(ValueError):
    """Raised when a Prepend/Append value contains an equals sign."""
163
164
165def _append_prepend_check(action):
166    if '=' in action.value:
167        raise BadVariableValue('"{}" contains "="'.format(action.value))
168
169
class Prepend(_VariableAction):
    """Prepend a value to a PATH-like variable.

    Args:
      name, value: forwarded to _VariableAction.
      join: Join object, stored as self._join.
      Any further arguments are forwarded to _VariableAction.
    """
    def accept(self, visitor):
        """Dispatch to visitor.visit_prepend()."""
        visitor.visit_prepend(self)

    def _check(self):
        """Run the standard variable checks, then reject '=' in the value."""
        super(Prepend, self)._check()
        _append_prepend_check(self)

    def __init__(self, name, value, join, *args, **kwargs):
        super(Prepend, self).__init__(name, value, *args, **kwargs)
        self._join = join
182
183
class Append(_VariableAction):
    """Append a value to a PATH-like variable. (Uncommon, see Prepend.)

    Args:
      name, value: forwarded to _VariableAction.
      join: Join object, stored as self._join.
      Any further arguments are forwarded to _VariableAction.
    """
    def accept(self, visitor):
        """Dispatch to visitor.visit_append()."""
        visitor.visit_append(self)

    def _check(self):
        """Run the standard variable checks, then reject '=' in the value."""
        super(Append, self)._check()
        _append_prepend_check(self)

    def __init__(self, name, value, join, *args, **kwargs):
        super(Append, self).__init__(name, value, *args, **kwargs)
        self._join = join
196
197
class BadEchoValue(ValueError):
    """Raised when an Echo value is 'on' or 'off' (in any letter case)."""
200
201
class Echo(_Action):
    """Echo a value to the terminal.

    Args:
      value(str): text to echo; 'on'/'off' in any case are rejected.
      newline(bool): stored as self.newline.

    Raises:
      BadEchoValue: value is 'on' or 'off' (case-insensitive).
    """
    def __init__(self, value, newline, *args, **kwargs):
        # 'echo on' / 'echo off' toggle command echoing in Windows batch
        # files rather than printing text, so refuse them outright.
        lowered = value.lower()
        if lowered == 'off' or lowered == 'on':
            raise BadEchoValue(value)
        super(Echo, self).__init__(*args, **kwargs)
        self.value = value
        self.newline = newline

    def __repr__(self):
        return 'Echo({}, newline={})'.format(self.value, self.newline)

    def accept(self, visitor):
        """Dispatch to visitor.visit_echo()."""
        visitor.visit_echo(self)
217
218
class Comment(_Action):
    """Add a comment to the init script.

    Args:
      value: comment text, stored as self.value.
    """
    def __init__(self, value, *args, **kwargs):
        super(Comment, self).__init__(*args, **kwargs)
        self.value = value

    def __repr__(self):
        return 'Comment({})'.format(self.value)

    def accept(self, visitor):
        """Dispatch to visitor.visit_comment()."""
        visitor.visit_comment(self)
230
231
class Command(_Action):
    """Run a command.

    Args:
      command: list or tuple of arguments (checked with an assert).
      exit_on_error(bool): keyword-only, default True; stored on the
        instance.
    """
    def __init__(self, command, *args, **kwargs):
        # Pop before forwarding so _Action never sees 'exit_on_error'.
        stop_on_failure = kwargs.pop('exit_on_error', True)
        super(Command, self).__init__(*args, **kwargs)
        assert isinstance(command, (list, tuple))
        self.command = command
        self.exit_on_error = stop_on_failure

    def __repr__(self):
        return 'Command({})'.format(self.command)

    def accept(self, visitor):
        """Dispatch to visitor.visit_command()."""
        visitor.visit_command(self)
246
247
class Doctor(Command):
    """Run 'pw doctor'.

    The log level is 'warn' if PW_ENVSETUP_QUIET is present in os.environ
    when the object is constructed, and 'info' otherwise.
    """
    def __init__(self, *args, **kwargs):
        if 'PW_ENVSETUP_QUIET' in os.environ:
            level = 'warn'
        else:
            level = 'info'
        doctor_cmd = ['pw', '--no-banner', '--loglevel', level, 'doctor']
        super(Doctor, self).__init__(command=doctor_cmd, *args, **kwargs)

    def __repr__(self):
        return 'Doctor()'

    def accept(self, visitor):
        """Dispatch to visitor.visit_doctor()."""
        visitor.visit_doctor(self)
261
262
class BlankLine(_Action):
    """Write a blank line to the init script."""
    def __repr__(self):
        return 'BlankLine()'

    def accept(self, visitor):
        """Dispatch to visitor.visit_blank_line()."""
        visitor.visit_blank_line(self)
270
271
class Function(_Action):
    """Record a named function definition for visitors to render.

    Args:
      name: function name, stored as self.name.
      body: function body, stored as self.body.
    """
    def __init__(self, name, body, *args, **kwargs):
        super(Function, self).__init__(*args, **kwargs)
        self.name, self.body = name, body

    def __repr__(self):
        return 'Function({}, {})'.format(self.name, self.body)

    def accept(self, visitor):
        """Dispatch to visitor.visit_function()."""
        visitor.visit_function(self)
283
284
class Hash(_Action):
    """Action dispatched to visitor.visit_hash().

    NOTE(review): what each visitor emits for it is not visible in this file.
    """
    def __repr__(self):
        return 'Hash()'

    def accept(self, visitor):
        """Dispatch to visitor.visit_hash()."""
        visitor.visit_hash(self)
291
292
class Join(object):  # pylint: disable=useless-object-inheritance
    """Wraps a path separator; passed as 'join' to Prepend and Append.

    Args:
      pathsep(str): the separator, stored as self.pathsep (default
        os.pathsep).
    """
    def __init__(self, pathsep=os.pathsep):
        self.pathsep = pathsep
296
297
298# TODO(mohrr) remove disable=useless-object-inheritance once in Python 3.
299# pylint: disable=useless-object-inheritance
class Environment(object):
    """Stores the environment changes necessary for Pigweed.

    These changes can be accessed by writing them to a file for bash-like
    shells to source or by using this as a context manager.

    Changes are recorded as an ordered list of _Action objects that visitors
    (shell, batch, GNI, JSON, apply) walk via accept().

    Keyword Args:
      pathsep(str): separator for PATH-like variables; defaults to
        os.pathsep.
      windows(bool): target Windows (write() emits batch output and
        finalize() skips the deactivate function); defaults to
        os.name == 'nt'.
      allcaps(bool): upper-case variable names in normalize_key(); defaults
        to the value of windows.
    """
    def __init__(self, *args, **kwargs):
        pathsep = kwargs.pop('pathsep', os.pathsep)
        windows = kwargs.pop('windows', os.name == 'nt')
        allcaps = kwargs.pop('allcaps', windows)
        super(Environment, self).__init__(*args, **kwargs)
        self._actions = []
        self._pathsep = pathsep
        self._windows = windows
        self._allcaps = allcaps
        self.replacements = []
        self._join = Join(pathsep)
        self._finalized = False

    def add_replacement(self, variable, value=None):
        """Record a (variable, value) pair in self.replacements.

        NOTE(review): the pairs are only stored here; their consumers are not
        visible in this class.
        """
        self.replacements.append((variable, value))

    def normalize_key(self, name):
        """Return name upper-cased if allcaps is enabled, else unchanged."""
        if self._allcaps:
            try:
                return name.upper()
            except AttributeError:
                # The _Action class has code to handle incorrect types, so
                # we just ignore this error here.
                pass
        return name

    # A newline is printed after each high-level operation. Top-level
    # operations should not invoke each other (this is why _remove() exists).

    def set(self, name, value, deactivate=True):
        """Set a variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        self._actions.append(Set(name, value, deactivate=deactivate))
        self._blankline()

    def clear(self, name):
        """Remove a variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        self._actions.append(Clear(name))
        self._blankline()

    def _remove(self, name, value):
        """Remove a value from a variable, if the variable is currently set.

        Unlike remove(), no blank line is appended, so other top-level
        operations can call this.
        """
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # Remove takes _VariableAction's (name, value, allow_empty_values)
            # arguments; passing the path separator third would silently bind
            # it to allow_empty_values. The visitors are given pathsep
            # directly when they are constructed (see write()).
            self._actions.append(Remove(name, value))

    def remove(self, name, value):
        """Remove a value from a PATH-like variable.

        value must be a non-empty string; it is validated like the values of
        the other variable actions.
        """
        assert not self._finalized
        self._remove(name, value)
        self._blankline()

    def append(self, name, value):
        """Add a value to a PATH-like variable. Rarely used, see prepend()."""
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # Drop any existing occurrence first so the value is not listed
            # twice.
            self._remove(name, value)
            self._actions.append(Append(name, value, self._join))
        else:
            self._actions.append(Set(name, value))
        self._blankline()

    def prepend(self, name, value):
        """Add a value to the beginning of a PATH-like variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # Drop any existing occurrence first so the value is not listed
            # twice.
            self._remove(name, value)
            self._actions.append(Prepend(name, value, self._join))
        else:
            self._actions.append(Set(name, value))
        self._blankline()

    def echo(self, value='', newline=True):
        """Echo a value to the terminal."""
        # echo() deliberately ignores self._finalized.
        self._actions.append(Echo(value, newline))
        if value:
            self._blankline()

    def comment(self, comment):
        """Add a comment to the init script."""
        # comment() deliberately ignores self._finalized.
        self._actions.append(Comment(comment))
        self._blankline()

    def command(self, command, exit_on_error=True):
        """Run a command."""
        # command() deliberately ignores self._finalized.
        self._actions.append(Command(command, exit_on_error=exit_on_error))
        self._blankline()

    def doctor(self):
        """Run 'pw doctor'.

        Unlike most operations, this neither checks self._finalized nor
        appends a blank line.
        """
        self._actions.append(Doctor())

    def function(self, name, body):
        """Define a function.

        Args:
          name: function name.
          body: function body (finalize() passes a str here).
        """
        assert not self._finalized
        # Function, not Command: visitors render it via visit_function().
        # Command(name, body) would fail, since Command requires a
        # list/tuple and forwards extra positional args to _Action.
        self._actions.append(Function(name, body))
        self._blankline()

    def _blankline(self):
        """Append a BlankLine action."""
        self._actions.append(BlankLine())

    def finalize(self):
        """Run cleanup at the end of environment setup.

        Appends a Hash action and, when not targeting Windows, a
        '_pw_deactivate' function whose body is the output of
        write_deactivate(). May only be called once.
        """
        assert not self._finalized
        self._finalized = True
        self._actions.append(Hash())
        self._blankline()

        if not self._windows:
            buf = StringIO()
            self.write_deactivate(buf)
            self._actions.append(Function('_pw_deactivate', buf.getvalue()))
            self._blankline()

    def accept(self, visitor):
        """Have visitor visit every recorded action, in order."""
        for action in self._actions:
            action.accept(visitor)

    def gni(self, outs, project_root):
        """Serialize the environment to outs using the GNI visitor."""
        gni_visitor.GNIVisitor(project_root).serialize(self, outs)

    def json(self, outs):
        """Serialize the environment to outs using the JSON visitor."""
        json_visitor.JSONVisitor().serialize(self, outs)

    def write(self, outs):
        """Write a batch (Windows) or shell script applying the changes."""
        if self._windows:
            visitor = batch_visitor.BatchVisitor(pathsep=self._pathsep)
        else:
            visitor = shell_visitor.ShellVisitor(pathsep=self._pathsep)
        visitor.serialize(self, outs)

    def write_deactivate(self, outs):
        """Write a shell script undoing the changes; no-op on Windows."""
        if self._windows:
            return
        visitor = shell_visitor.DeactivateShellVisitor(pathsep=self._pathsep)
        visitor.serialize(self, outs)

    @contextlib.contextmanager
    def __call__(self, export=True):
        """Set environment as if this was written to a file and sourced.

        Within this context os.environ is updated with the environment
        defined by this object. If export is False, os.environ is not updated,
        but in both cases the updated environment is yielded.

        On exit, previous environment is restored. See contextlib documentation
        for details on how this function is structured.

        Args:
          export(bool): modify the environment of the running process (and
            thus, its subprocesses)

        Yields the new environment object.
        """
        try:
            if export:
                orig_env = os.environ.copy()
                env = os.environ
            else:
                env = os.environ.copy()

            apply = apply_visitor.ApplyVisitor(pathsep=self._pathsep)
            apply.apply(self, env)

            yield env

        finally:
            if export:
                # Restore or delete every key now present, then re-add any
                # original keys that were removed while inside the context.
                for key in set(os.environ):
                    try:
                        os.environ[key] = orig_env[key]
                    except KeyError:
                        del os.environ[key]
                for key in set(orig_env) - set(os.environ):
                    os.environ[key] = orig_env[key]

    def get(self, key, default=None):
        """Get the value of a variable within context of this object."""
        key = self.normalize_key(key)
        with self(export=False) as env:
            return env.get(key, default)

    def __getitem__(self, key):
        """Get the value of a variable within context of this object."""
        key = self.normalize_key(key)
        with self(export=False) as env:
            return env[key]
502