File: service.go

package info (click to toggle)
golang-github-viant-toolbox 0.33.2-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 1,280 kB
  • sloc: makefile: 16
file content (277 lines) | stat: -rw-r--r-- 7,012 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
package ssh

import (
	"bytes"
	"fmt"
	"github.com/pkg/errors"
	"github.com/viant/toolbox/cred"
	"github.com/viant/toolbox/storage"
	"golang.org/x/crypto/ssh"
	"io"
	"net"
	"os"
	"path"
	"strings"
	"sync"
	"time"
)

type (
	//Service represents an ssh service: a connection to a remote host with helpers to run
	//commands, transfer files and open tunnels.
	Service interface {
		//Client returns the underlying ssh client.
		Client() *ssh.Client

		//OpenMultiCommandSession opens a multi command session configured by config.
		OpenMultiCommandSession(config *SessionConfig) (MultiCommandSession, error)

		//Run runs supplied command in a new session.
		Run(command string) error

		//Upload uploads provided content to specified destination using the given file mode.
		//
		//Deprecated: please consider using https://github.com/viant/afs/tree/master/scp
		Upload(destination string, mode os.FileMode, content []byte) error

		//Download downloads content from specified source.
		//
		//Deprecated: please consider using https://github.com/viant/afs/tree/master/scp
		Download(source string) ([]byte, error)

		//OpenTunnel opens a tunnel between local to remote for network traffic.
		OpenTunnel(localAddress, remoteAddress string) error

		//NewSession opens a new session on the underlying ssh client.
		NewSession() (*ssh.Session, error)

		//Close closes the service together with any tunnels it opened.
		Close() error
	}
)

//service represents an SSH service; it is the Service implementation returned by NewService.
type service struct {
	//host is the "host:port" address passed to ssh.Dial by connect.
	host           string
	//client is the current ssh connection; connect assigns it.
	client         *ssh.Client
	//forwarding holds the tunnels opened via OpenTunnel; Close closes each of them.
	forwarding     []*Tunnel
	//replayCommands is forwarded to newMultiCommandSession (semantics defined elsewhere).
	replayCommands *ReplayCommands
	//recordSession is forwarded to newMultiCommandSession; presumably enables session recording — confirm.
	recordSession  bool
	//config is the ssh client config reused on every (re)connect.
	config         *ssh.ClientConfig
}

//Client returns the underlying ssh client used by this service.
func (c *service) Client() *ssh.Client {
	return c.client
}

//NewSession opens a new session on the underlying ssh client.
//The returned session is not closed by the service; callers should close it when done.
func (c *service) NewSession() (*ssh.Session, error) {
	return c.client.NewSession()
}

//OpenMultiCommandSession creates a new MultiCommandSession for this service using config,
//passing along the service's replay commands and session-recording flag.
func (c *service) OpenMultiCommandSession(config *SessionConfig) (MultiCommandSession, error) {
	return newMultiCommandSession(c, config, c.replayCommands, c.recordSession)
}

//Run executes command in a fresh session and returns the result of session.Run.
//Failure to open the session is returned as an error instead of panicking, since
//this is library code and callers (e.g. Upload) expect an error value.
func (c *service) Run(command string) error {
	session, err := c.client.NewSession()
	if err != nil {
		return fmt.Errorf("failed to create session: %w", err)
	}
	defer session.Close()
	return session.Run(command)
}

//transferData writes one file record to writer in three steps: the createFileCmd header,
//the payload bytes, and a terminating NUL byte. It stops at the first failed write and sends
//that error on errorChannel. waitGroup.Done is called on return in every case.
func (c *service) transferData(payload []byte, createFileCmd string, writer io.Writer, errorChannel chan error, waitGroup *sync.WaitGroup) {
	defer waitGroup.Done()
	const endSequence = "\x00"

	writeHeader := func() error {
		_, writeErr := fmt.Fprint(writer, createFileCmd)
		return writeErr
	}
	writePayload := func() error {
		_, copyErr := io.Copy(writer, bytes.NewReader(payload))
		return copyErr
	}
	writeTrailer := func() error {
		_, writeErr := fmt.Fprint(writer, endSequence)
		return writeErr
	}

	for _, step := range []func() error{writeHeader, writePayload, writeTrailer} {
		if stepErr := step(); stepErr != nil {
			errorChannel <- stepErr
			return
		}
	}
}

//Errors is a channel of errors reported by the background helpers used during an upload.
type Errors chan error

//GetError returns the next error from the channel, waiting at most one millisecond.
//It returns nil when nothing arrives in that window; a closed channel also yields nil.
func (e Errors) GetError() error {
	deadline := time.NewTimer(time.Millisecond)
	defer deadline.Stop()
	select {
	case err := <-e:
		return err
	case <-deadline.C:
		return nil
	}
}

//operationSuccessful is the response byte value that checkOutput treats as success.
const operationSuccessful = 0

//checkOutput reads everything from reader (the remote scp's stdout) and inspects the second
//response byte; presumably the first byte is scp's initial acknowledgement — confirm.
//If that byte is not operationSuccessful and a message follows it, the message (bytes from
//index 2 onward) is sent on errorChannel. Read errors are ignored.
func checkOutput(reader io.Reader, errorChannel Errors) {
	var response bytes.Buffer
	_, _ = io.Copy(&response, reader)
	data := response.Bytes()
	if len(data) < 3 || data[1] == operationSuccessful {
		return
	}
	errorChannel <- errors.New(string(data[2:]))
}

//Upload uploads content into the remote destination file with the given mode.
//
//If the first attempt fails because the remote directory does not exist, the directory is
//created with "mkdir -p" and the upload is retried once. If it fails with a handshake or
//connection error, the service reconnects after a short pause and retries once. Failures of
//the recovery step itself (mkdir or reconnect) are returned, wrapped, instead of being ignored.
func (c *service) Upload(destination string, mode os.FileMode, content []byte) (err error) {
	err = c.upload(destination, mode, content)
	if err == nil {
		return nil
	}
	message := err.Error()
	switch {
	case strings.Contains(message, "No such file or directory"):
		dir, _ := path.Split(destination)
		if runErr := c.Run("mkdir -p " + dir); runErr != nil {
			return fmt.Errorf("failed to create remote directory %v after upload error (%v): %w", dir, err, runErr)
		}
		return c.upload(destination, mode, content)
	case strings.Contains(message, "handshake") || strings.Contains(message, "connection"):
		// Give the remote side a moment before re-dialing.
		time.Sleep(500 * time.Millisecond)
		if reconnectErr := c.Reconnect(); reconnectErr != nil {
			return fmt.Errorf("failed to reconnect after upload error (%v): %w", err, reconnectErr)
		}
		return c.upload(destination, mode, content)
	}
	return err
}

//getSession opens a fresh session on the service's current ssh client.
func (c *service) getSession() (*ssh.Session, error) {
	client := c.client
	session, err := client.NewSession()
	return session, err
}

//upload copies content to destination on the remote host by starting "scp -qtr <dir>" and
//sending a single file record (header, payload, NUL terminator) on its stdin.
//
//A mode of 0 defaults to 0644; modes >= 01000 fall back to storage.DefaultFileMode.
//An error message reported by the remote scp takes precedence over local write errors and the
//session exit status, so that Upload can inspect it (e.g. "No such file or directory").
func (c *service) upload(destination string, mode os.FileMode, content []byte) (err error) {
	// outputDrainTimeout bounds how long we wait for the scp response reader after the
	// remote command has exited, so a misbehaving server cannot hang the upload.
	const outputDrainTimeout = 2 * time.Second

	// path.Split guarantees that file contains no "/".
	dir, file := path.Split(destination)
	if dir == "" {
		// A bare file name is relative to the remote working directory; scp -t needs a target.
		dir = "."
	}
	if mode == 0 {
		mode = 0644
	}
	if mode >= 01000 {
		mode = storage.DefaultFileMode
	}

	session, err := c.getSession()
	if err != nil {
		return err
	}
	// Release the session on every path, including early error returns.
	defer session.Close()

	writer, err := session.StdinPipe()
	if err != nil {
		return errors.Wrap(err, "failed to acquire stdin")
	}
	defer writer.Close()
	output, err := session.StdoutPipe()
	if err != nil {
		return errors.Wrap(err, "failed to acquire stdout")
	}

	// Each producer sends at most once into a buffer of one, so sends never block. The channels
	// are deliberately not closed here: closing them from the receiving side could race with a
	// late send from a producer goroutine and panic with "send on closed channel".
	var transferError Errors = make(chan error, 1)
	var sessionError Errors = make(chan error, 1)

	outputDone := make(chan struct{})
	go func() {
		defer close(outputDone)
		checkOutput(output, sessionError)
	}()

	if err = session.Start("scp -qtr " + dir); err != nil {
		return err
	}

	fileMode := fmt.Sprintf("C%04o", mode)[:5]
	createFileCmd := fmt.Sprintf("%v %d %s\n", fileMode, len(content), file)
	waitGroup := &sync.WaitGroup{}
	waitGroup.Add(1)
	go c.transferData(content, createFileCmd, writer, transferError, waitGroup)
	waitGroup.Wait()
	// Closing stdin signals EOF to the remote scp so it can finish and exit.
	writerErr := writer.Close()

	// transferData sends its error before calling Done, so this check is reliable here.
	if transferErr := transferError.GetError(); transferErr != nil {
		if scpErr := sessionError.GetError(); scpErr != nil {
			return scpErr
		}
		return transferErr
	}

	waitErr := session.Wait()
	// Let checkOutput consume the remaining scp responses so a remote error is not missed.
	select {
	case <-outputDone:
	case <-time.After(outputDrainTimeout):
	}
	if scpErr := sessionError.GetError(); scpErr != nil {
		return scpErr
	}
	if waitErr != nil {
		return waitErr
	}
	return writerErr
}

//Download returns the content of the remote file source, obtained by running "cat" on it
//in a new session that is closed before returning.
//NOTE(review): source is interpolated into the remote command unquoted.
func (c *service) Download(source string) ([]byte, error) {
	command := fmt.Sprintf("cat %s", source)
	session, err := c.client.NewSession()
	if err != nil {
		return nil, err
	}
	defer session.Close()
	return session.Output(command)
}

//Host returns the "host:port" address this service dials.
func (c *service) Host() string {
	address := c.host
	return address
}

//Close closes every tunnel opened with OpenTunnel (their individual close errors are
//deliberately ignored) and then closes the underlying ssh client, returning its error.
func (c *service) Close() error {
	for i := range c.forwarding {
		_ = c.forwarding[i].Close()
	}
	return c.client.Close()
}

//Reconnect dials the configured host again using the stored client config (see connect)
//and returns any dial error.
func (c *service) Reconnect() error {
	if err := c.connect(); err != nil {
		return err
	}
	return nil
}

//OpenTunnel listens on localAddress (TCP) and hands the listener, the ssh client and
//remoteAddress to a Tunnel created by NewForwarding, whose Handle method runs in a background
//goroutine. The tunnel is recorded so that Close can shut it down.
func (c *service) OpenTunnel(localAddress, remoteAddress string) error {
	local, err := net.Listen("tcp", localAddress)
	if err != nil {
		return errors.Wrap(err, fmt.Sprintf("failed to listen on local: %v (tunnel to %v)", localAddress, remoteAddress))
	}
	forwarding := NewForwarding(c.client, remoteAddress, local)
	// append allocates the slice on first use; no explicit initialisation is needed.
	c.forwarding = append(c.forwarding, forwarding)
	go forwarding.Handle()
	return nil
}

//connect dials c.host over TCP with c.config and, only on success, stores the new client in
//c.client. On failure the existing client is left untouched, so a failed reconnect does not
//replace a usable (or at least non-nil) client with nil.
func (c *service) connect() error {
	client, err := ssh.Dial("tcp", c.host, c.config)
	if err != nil {
		return errors.Wrap(err, fmt.Sprintf("failed to dial %v", c.host))
	}
	c.client = client
	return nil
}

//NewService creates a new ssh service connected to host:port, authenticating with authConfig
//(an empty cred.Config is used when authConfig is nil).
//
//The address is built with net.JoinHostPort so IPv6 hosts (bare or already bracketed) are
//handled correctly. On any error a nil Service is returned together with the error.
func NewService(host string, port int, authConfig *cred.Config) (Service, error) {
	if authConfig == nil {
		authConfig = &cred.Config{}
	}
	clientConfig, err := authConfig.ClientConfig()
	if err != nil {
		return nil, err
	}
	result := &service{
		// Trim brackets so "[::1]" and "::1" both become "[::1]:port".
		host:   net.JoinHostPort(strings.Trim(host, "[]"), fmt.Sprint(port)),
		config: clientConfig,
	}
	if err := result.connect(); err != nil {
		return nil, err
	}
	return result, nil
}