• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16package utils
17
import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"path"
	"path/filepath"
	"time"

	"github.com/pkg/sftp"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
30
31func newSSHClient(addr string, user string, passwd string) (*ssh.Client, error) {
32	config := &ssh.ClientConfig{
33		User:            user,
34		Auth:            []ssh.AuthMethod{ssh.Password(passwd)},
35		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
36	}
37	config.SetDefaults()
38	return ssh.Dial("tcp", addr, config)
39}
40
// RunCmdViaSSHContext runs cmd in a remote shell on addr (see
// RunCmdViaSSHContextNoRetry) and retries it once if the first attempt fails.
//
// Both attempts together are bounded by a 6-hour timeout layered on top of
// ctx. No retry is made once ctx has been canceled or its deadline has passed:
// a second attempt could not succeed, and it would still dial the host and
// possibly start the command before being torn down.
//
// NOTE(review): every other failure triggers a second run, including the
// remote command exiting with a non-zero status, so cmd should be safe to
// run twice.
func RunCmdViaSSHContext(ctx context.Context, addr string, user string, passwd string, cmd string) (err error) {
	ctx, cancel := context.WithTimeout(ctx, 6*time.Hour)
	defer cancel()
	if err = RunCmdViaSSHContextNoRetry(ctx, addr, user, passwd, cmd); err == nil {
		return nil
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil {
		return err
	}
	logrus.Errorf("exec cmd via SSH at %s failed: %v, try again...", addr, err)
	return RunCmdViaSSHContextNoRetry(ctx, addr, user, passwd, cmd)
}
53
// RunCmdViaSSHContextNoRetry runs cmd once in a remote shell on addr and
// streams the remote stdout/stderr to the local os.Stdout/os.Stderr.
//
// cmd is written to the shell's stdin followed by "exit $?", so the shell's
// exit status (and therefore the error returned by session.Wait) reflects
// cmd's own exit status. If ctx is done, the session is closed and the
// returned error is ctx.Err(), whatever the remote side reported.
func RunCmdViaSSHContextNoRetry(ctx context.Context, addr string, user string, passwd string, cmd string) (err error) {
	client, err := newSSHClient(addr, user, passwd)
	if err != nil {
		logrus.Errorf("new SSH client to %s err: %v", addr, err)
		return err
	}
	defer client.Close()
	session, err := client.NewSession()
	if err != nil {
		return err
	}
	// Prefer reporting cancellation/timeout over whatever error the torn-down
	// session produced.
	defer func() {
		select {
		case <-ctx.Done():
			err = ctx.Err()
		default:
		}
	}()
	// Close the session as soon as ctx is done, or when this function
	// returns, whichever happens first.
	exit := make(chan struct{})
	defer close(exit)
	go func() {
		select {
		case <-ctx.Done():
		case <-exit:
		}
		session.Close()
	}()
	logrus.Infof("run at %s: %s", addr, cmd)
	stdin, err := session.StdinPipe()
	if err != nil {
		return err
	}
	defer stdin.Close()
	// Assign writers instead of using StdoutPipe/StderrPipe with detached
	// io.Copy goroutines: session.Wait then also waits for all remote output
	// to be copied, so output is neither truncated nor printed after return.
	session.Stdout = os.Stdout
	session.Stderr = os.Stderr
	if err := session.Shell(); err != nil {
		return err
	}
	cmd = fmt.Sprintf("%s\nexit $?\n", cmd)
	// Written asynchronously so a blocked channel write cannot stall Wait;
	// the write result is intentionally ignored (a failure shows up in Wait).
	go stdin.Write([]byte(cmd))
	fmt.Printf("[%s] exec at %s %s :\n", time.Now(), addr, cmd)
	return session.Wait()
}
105
// Direct selects the direction of a file transfer made by TransFileViaSSH.
type Direct string

// Transfer directions understood by TransFileViaSSH.
const (
	// Download copies a remote file to the local machine.
	Download = Direct("download")
	// Upload copies a local file to the remote machine.
	Upload = Direct("upload")
)
112
// TransFileViaSSH copies a single file between the local machine and the host
// at addr over SFTP, authenticating as user with passwd.
//
// When verb is Download, remoteFile is copied to localFile; any other value is
// treated as an upload of localFile to remoteFile. Before writing, an existing
// destination is removed and its parent directory is created. Both steps are
// best-effort; a real problem surfaces when the destination is created.
//
// The destination is closed before success is reported, and a Close failure
// is returned as an error because it can mean the written data is incomplete.
func TransFileViaSSH(verb Direct, addr string, user string, passwd string, remoteFile string, localFile string) error {
	c, err := newSSHClient(addr, user, passwd)
	if err != nil {
		logrus.Errorf("new SSH client to %s err: %v", addr, err)
		return err
	}
	defer c.Close()
	client, err := sftp.NewClient(c)
	if err != nil {
		logrus.Errorf("new SFTP client to %s err: %v", addr, err)
		return err
	}
	defer client.Close()

	var (
		prep    string
		dstName string
		src     io.ReadCloser
		dst     io.WriteCloser
	)
	if verb == Download {
		prep, dstName = "to", localFile
		if src, err = client.Open(remoteFile); err != nil {
			return fmt.Errorf("open remote file %s at %s err: %w", remoteFile, addr, err)
		}
		defer src.Close()
		// NOTE(review): RemoveAll also deletes localFile if it is a directory
		// tree; os.Remove would be safer if that is never intended.
		os.RemoveAll(localFile)
		os.MkdirAll(filepath.Dir(localFile), 0755)
		if dst, err = os.Create(localFile); err != nil {
			return fmt.Errorf("create local file err: %w", err)
		}
	} else {
		prep, dstName = "from", remoteFile
		if src, err = os.Open(localFile); err != nil {
			return fmt.Errorf("open local file err: %w", err)
		}
		defer src.Close()
		client.Remove(remoteFile)
		// SFTP paths are always slash-separated; filepath.Dir would produce
		// backslash-separated paths when this runs on Windows.
		client.MkdirAll(path.Dir(remoteFile))
		if dst, err = client.Create(remoteFile); err != nil {
			return fmt.Errorf("create remote file %s at %s err: %w", remoteFile, addr, err)
		}
	}

	logrus.Infof("%sing %s at %s %s %s...", verb, remoteFile, addr, prep, localFile)
	start := time.Now()
	// *sftp.File implements io.WriterTo and io.ReaderFrom, so the copy is
	// delegated to the SFTP client's own implementation in both directions;
	// a caller-supplied buffer (as io.CopyBuffer takes) would never be used.
	n, err := io.Copy(dst, src)
	// dst is closed here rather than deferred so that a Close failure is
	// reported instead of being dropped after a "done" log line.
	if cerr := dst.Close(); err == nil && cerr != nil {
		err = fmt.Errorf("close %s err: %w", dstName, cerr)
	}
	if err != nil {
		logrus.Errorf("%s %s at %s %s %s err: %v", verb, remoteFile, addr, prep, localFile, err)
		return err
	}
	cost := time.Since(start).Seconds()
	logrus.Infof("%s %s at %s %s %s done, size: %d cost: %.2fs speed: %.2fMB/s", verb, remoteFile, addr, prep, localFile, n, cost, float64(n)/cost/1024/1024)
	return nil
}
165