1/* 2 * Copyright (c) 2022 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16package utils 17 18import ( 19 "context" 20 "errors" 21 "fmt" 22 "github.com/pkg/sftp" 23 "github.com/sirupsen/logrus" 24 "golang.org/x/crypto/ssh" 25 "io" 26 "os" 27 "path/filepath" 28 "time" 29) 30 31func newSSHClient(addr string, user string, passwd string) (*ssh.Client, error) { 32 config := &ssh.ClientConfig{ 33 User: user, 34 Auth: []ssh.AuthMethod{ssh.Password(passwd)}, 35 HostKeyCallback: ssh.InsecureIgnoreHostKey(), 36 } 37 config.SetDefaults() 38 return ssh.Dial("tcp", addr, config) 39} 40 41func RunCmdViaSSHContext(ctx context.Context, addr string, user string, passwd string, cmd string) (err error) { 42 ctx, fn := context.WithTimeout(ctx, 6*time.Hour) 43 defer fn() 44 if err := RunCmdViaSSHContextNoRetry(ctx, addr, user, passwd, cmd); err != nil { 45 if errors.Is(err, context.Canceled) { 46 return err 47 } 48 logrus.Errorf("exec cmd via SSH at %s failed: %v, try again...", addr, err) 49 return RunCmdViaSSHContextNoRetry(ctx, addr, user, passwd, cmd) 50 } 51 return nil 52} 53 54func RunCmdViaSSHContextNoRetry(ctx context.Context, addr string, user string, passwd string, cmd string) (err error) { 55 exit := make(chan struct{}) 56 client, err := newSSHClient(addr, user, passwd) 57 if err != nil { 58 logrus.Errorf("new SSH client to %s err: %v", addr, err) 59 return err 60 } 61 defer client.Close() 62 session, err := client.NewSession() 63 if err != nil { 64 return err 65 } 66 defer func() { 67 select { 68 case <-ctx.Done(): 69 err = ctx.Err() 70 default: 71 } 72 }() 73 defer close(exit) 74 go func() { 75 select { 76 case <-ctx.Done(): 77 case <-exit: 78 } 79 session.Close() 80 }() 81 logrus.Infof("run at %s: %s", addr, cmd) 82 stdin, err := session.StdinPipe() 83 if err != nil { 84 return err 85 } 86 defer stdin.Close() 87 stdout, err := session.StdoutPipe() 88 if err != nil { 89 return err 90 } 91 stderr, err := session.StderrPipe() 92 if err != nil { 93 return err 94 } 95 if err := session.Shell(); err != nil { 96 return err 97 } 98 cmd = fmt.Sprintf("%s\nexit $?\n", cmd) 99 go stdin.Write([]byte(cmd)) 100 go io.Copy(os.Stdout, stdout) 101 go io.Copy(os.Stderr, stderr) 102 fmt.Printf("[%s] exec at %s %s :\n", time.Now(), addr, cmd) 103 return session.Wait() 104} 105 106type Direct string 107 108const ( 109 Download Direct = "download" 110 Upload Direct = "upload" 111) 112 113func TransFileViaSSH(verb Direct, addr string, user string, passwd string, remoteFile string, localFile string) error { 114 c, err := newSSHClient(addr, user, passwd) 115 if err != nil { 116 logrus.Errorf("new SSH client to %s err: %v", addr, err) 117 return err 118 } 119 defer c.Close() 120 client, err := sftp.NewClient(c) 121 if err != nil { 122 logrus.Errorf("new SFTP client to %s err: %v", addr, err) 123 return err 124 } 125 defer client.Close() 126 var prep string 127 var src, dst io.ReadWriteCloser 128 if verb == Download { 129 prep = "to" 130 if src, err = client.Open(remoteFile); err != nil { 131 return fmt.Errorf("open remote file %s at %s err: %v", remoteFile, addr, err) 132 } 133 defer src.Close() 134 os.RemoveAll(localFile) 135 os.MkdirAll(filepath.Dir(localFile), 0755) 136 if dst, err = os.Create(localFile); err != nil { 137 return fmt.Errorf("create local file err: %v", err) 138 } 139 defer dst.Close() 140 } else { 141 prep = "from" 142 if src, err = os.Open(localFile); err != nil { 143 return fmt.Errorf("open local file err: %v", err) 144 } 145 defer src.Close() 146 client.Remove(remoteFile) 147 client.MkdirAll(filepath.Dir(remoteFile)) 148 if dst, err = client.Create(remoteFile); err != nil { 149 return fmt.Errorf("create remote file %s at %s err: %v", remoteFile, addr, err) 150 } 151 defer dst.Close() 152 } 153 logrus.Infof("%sing %s at %s %s %s...", verb, remoteFile, addr, prep, localFile) 154 t1 := time.Now() 155 n, err := io.CopyBuffer(dst, src, make([]byte, 32*1024*1024)) 156 if err != nil { 157 logrus.Errorf("%s %s at %s %s %s err: %v", verb, remoteFile, addr, prep, localFile, err) 158 return err 159 } 160 t2 := time.Now() 161 cost := t2.Sub(t1).Seconds() 162 logrus.Infof("%s %s at %s %s %s done, size: %d cost: %.2fs speed: %.2fMB/s", verb, remoteFile, addr, prep, localFile, n, cost, float64(n)/cost/1024/1024) 163 return nil 164} 165