// Copyright 2021 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// run_with_timeout is a utility that can kill a wrapped command after a configurable timeout,
// optionally running a command to collect debugging information first.

package main

import (
	"flag"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"syscall"
	"time"
)

var (
	timeout      = flag.Duration("timeout", 0, "time after which to kill command (example: 60s)")
	onTimeoutCmd = flag.String("on_timeout", "", "command to run with `PID=<pid> sh -c` after timeout.")
)

// usage prints the command-line synopsis, the flag defaults and a short description to stderr,
// then exits with status 2.
func usage() {
	fmt.Fprintf(os.Stderr, "usage: %s [--timeout N] [--on_timeout CMD] -- command [args...]\n", os.Args[0])
	flag.PrintDefaults()
	fmt.Fprintln(os.Stderr, "run_with_timeout is a utility that can kill a wrapped command after a configurable timeout,")
	fmt.Fprintln(os.Stderr, "optionally running a command to collect debugging information first.")

	os.Exit(2)
}

func main() {
	flag.Usage = usage
	flag.Parse()

	if flag.NArg() < 1 {
		fmt.Fprintln(os.Stderr, "command is required")
		usage()
	}

	err := runWithTimeout(flag.Arg(0), flag.Args()[1:], *timeout, *onTimeoutCmd,
		os.Stdin, os.Stdout, os.Stderr)
	if err != nil {
		// runWithTimeout already prefixes every error with what failed (the wrapped process,
		// the on_timeout command, or the timeout itself), so print it exactly once. A direct
		// type assertion to *exec.ExitError can never match here because the exit error is
		// returned wrapped.
		fmt.Fprintln(os.Stderr, "error:", err)
		os.Exit(1)
	}
}

// concurrentWriter wraps a writer to make it thread-safe to call Write.
type concurrentWriter struct {
	w io.Writer
	sync.Mutex
}

// Write writes the data to the wrapped writer with a lock to allow for concurrent calls.
// After Close it discards data and reports that nothing was written.
func (c *concurrentWriter) Write(data []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	if c.w == nil {
		return 0, nil
	}
	return c.w.Write(data)
}

// Close ends the concurrentWriter, causing future calls to Write to be no-ops. It does not close
// the underlying writer.
func (c *concurrentWriter) Close() {
	c.Lock()
	defer c.Unlock()
	c.w = nil
}

// runWithTimeout runs command with args, connected to stdin, stdout and stderr.
//
// If timeout is positive and the command is still running when it elapses, onTimeoutCmdStr
// (when non-empty) is run with `sh -c`, with PID=<pid of command> appended to the current
// environment. The command is then sent SIGKILL. A non-positive timeout waits indefinitely.
//
// It returns:
//   - nil if the command exits successfully;
//   - an error wrapping the *exec.ExitError if the command exits unsuccessfully;
//   - an error wrapping the on_timeout command's error if that command fails;
//   - a "timed out after ..." error otherwise.
//
// Output the command produces after runWithTimeout returns is discarded.
func runWithTimeout(command string, args []string, timeout time.Duration, onTimeoutCmdStr string,
	stdin io.Reader, stdout, stderr io.Writer) error {
	cmd := exec.Command(command, args...)

	// Wrap the writers in a locking writer so that cmd and onTimeoutCmd don't try to write to
	// stdout or stderr concurrently. Closing them on return turns output still arriving from a
	// killed process into a no-op rather than a write to the caller's writers.
	concurrentStdout := &concurrentWriter{w: stdout}
	concurrentStderr := &concurrentWriter{w: stderr}
	defer concurrentStdout.Close()
	defer concurrentStderr.Close()

	cmd.Stdin, cmd.Stdout, cmd.Stderr = stdin, concurrentStdout, concurrentStderr
	if err := cmd.Start(); err != nil {
		return err
	}

	// waitCh will signal the subprocess exited. It has room for one value so the goroutine can
	// deliver its result and exit even when nobody receives it (the timeout path below);
	// an unbuffered channel would leave it blocked on the send forever.
	waitCh := make(chan error, 1)
	go func() {
		waitCh <- cmd.Wait()
	}()

	// timeoutCh will signal the subprocess timed out if timeout was set. Receiving from a nil
	// channel blocks forever, so with no timeout the select only waits for the process.
	var timeoutCh <-chan time.Time
	if timeout > 0 {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		timeoutCh = timer.C
	}

	select {
	case err := <-waitCh:
		if exitErr, ok := err.(*exec.ExitError); ok {
			return fmt.Errorf("process exited with error: %w", exitErr)
		}
		return err
	case <-timeoutCh:
		// Continue below.
	}

	// Process timed out before exiting. The kill is deferred so that on_timeout can still
	// inspect the live process. Defers run in reverse order, so the kill happens before the
	// writers are closed. Signal's error is ignored: the process may have exited on its own
	// in the meantime. The process is not waited for here; the Wait goroutine above collects
	// its exit status once it dies.
	defer cmd.Process.Signal(syscall.SIGKILL)

	if onTimeoutCmdStr != "" {
		onTimeoutCmd := exec.Command("sh", "-c", onTimeoutCmdStr)
		onTimeoutCmd.Stdin, onTimeoutCmd.Stdout, onTimeoutCmd.Stderr = stdin, concurrentStdout, concurrentStderr
		onTimeoutCmd.Env = append(os.Environ(), fmt.Sprintf("PID=%d", cmd.Process.Pid))
		if err := onTimeoutCmd.Run(); err != nil {
			return fmt.Errorf("on_timeout command %q exited with error: %w", onTimeoutCmdStr, err)
		}
	}

	return fmt.Errorf("timed out after %s", timeout.String())
}