File: eventstream_test.go

package info (click to toggle)
golang-github-aws-aws-sdk-go-v2 1.24.1-2~bpo12%2B1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm-backports
  • size: 554,032 kB
  • sloc: java: 15,941; makefile: 419; sh: 175
file content (201 lines) | stat: -rw-r--r-- 5,495 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
//go:build integration
// +build integration

package transcribestreaming

import (
	"bytes"
	"context"
	"encoding/base64"
	"flag"
	"fmt"
	"io"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/internal/integrationtest"
	"github.com/aws/aws-sdk-go-v2/service/transcribestreaming"
	"github.com/aws/aws-sdk-go-v2/service/transcribestreaming/types"
)

// Settings for the integration tests in this file. Each value is bound to a
// command-line flag by init.
var (
	// Source file, media encoding, and language code of the audio to transcribe.
	audioFilename, audioFormat, audioLang string

	// Sample rate of the audio and the number of bytes uploaded per audio event.
	audioSampleRate, audioFrameSize int

	// withDebug is set by the -debug flag.
	withDebug bool
)

func init() {
	flag.BoolVar(&withDebug, "debug", false, "Include debug logging with test.")
	flag.StringVar(&audioFilename, "audio-file", "", "Audio file filename to perform test with.")
	flag.StringVar(&audioLang, "audio-lang", string(types.LanguageCodeEnUs), "Language of audio speech.")
	flag.StringVar(&audioFormat, "audio-format", string(types.MediaEncodingPcm), "Format of audio.")
	flag.IntVar(&audioSampleRate, "audio-sample", 16000, "Sample rate of the audio.")
	flag.IntVar(&audioFrameSize, "audio-frame", 15*1024, "Size of frames of audio uploaded.")
}

// TestInteg_StartStreamTranscription starts a streaming transcription, uploads
// audio to it, and logs every transcript alternative the service returns.
//
// The audio is read from the file named by the -audio-file flag when that flag
// is set; otherwise a small embedded base64 sample is used. The request and the
// audio upload share a single 10 second timeout. The test fails if the
// configuration cannot be loaded, the stream cannot be started, an unexpected
// event type arrives, the upload fails, or the stream reports an error.
func TestInteg_StartStreamTranscription(t *testing.T) {
	var audio io.Reader
	if len(audioFilename) != 0 {
		audioFile, err := os.Open(audioFilename)
		if err != nil {
			t.Fatalf("expect to open file, %v", err)
		}
		defer audioFile.Close()
		audio = audioFile
	} else {
		b, err := base64.StdEncoding.DecodeString(
			`UklGRjzxPQBXQVZFZm10IBAAAAABAAEAgD4AAAB9AAACABAAZGF0YVTwPQAAAAAAAAAAAAAAAAD//wIA/f8EAA==`,
		)
		if err != nil {
			t.Fatalf("expect decode audio bytes, %v", err)
		}
		audio = bytes.NewReader(b)
	}

	cfg, err := integrationtest.LoadConfigWithDefaultRegion("us-west-2")
	if err != nil {
		t.Fatalf("failed to load config, %v", err)
	}

	ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelFn()

	client := transcribestreaming.NewFromConfig(cfg, func(o *transcribestreaming.Options) {
		// Signing logs are only emitted when the -debug flag asks for them.
		if withDebug {
			o.ClientLogMode = aws.LogSigning
		}
	})
	resp, err := client.StartStreamTranscription(ctx, &transcribestreaming.StartStreamTranscriptionInput{
		LanguageCode:         types.LanguageCode(audioLang),
		MediaEncoding:        types.MediaEncoding(audioFormat),
		MediaSampleRateHertz: aws.Int32(int32(audioSampleRate)),
	})
	if err != nil {
		t.Fatalf("failed to start streaming, %v", err)
	}
	stream := resp.GetStream()
	defer stream.Close()

	// Upload the audio concurrently with reading events. The upload uses the
	// same timeout context as the request so it cannot outlive the test, and
	// its result is handed back through a buffered channel so the goroutine
	// never blocks even if the test exits early.
	sendErrCh := make(chan error, 1)
	go func() {
		sendErrCh <- streamAudioFromReader(ctx, stream.Writer, audioFrameSize, audio)
	}()

	for event := range stream.Events() {
		switch e := event.(type) {
		case *types.TranscriptResultStreamMemberTranscriptEvent:
			t.Logf("got event, %v results", len(e.Value.Transcript.Results))
			for _, res := range e.Value.Transcript.Results {
				for _, alt := range res.Alternatives {
					t.Logf("* %s", aws.ToString(alt.Transcript))
				}
			}
		default:
			t.Fatalf("unexpected event, %T", event)
		}
	}

	if err := <-sendErrCh; err != nil {
		t.Errorf("expect no error sending audio, got %v", err)
	}

	if err := stream.Err(); err != nil {
		t.Fatalf("expect no error from stream, got %v", err)
	}
}

// TestInteg_StartStreamTranscription_contextClose verifies that uploading audio
// with a canceled context fails with a "context canceled" error while the
// stream itself reports no error.
//
// The stream is started with a background context; a separate cancelable
// context is used only for the audio upload and is canceled right after the
// upload goroutine is launched.
func TestInteg_StartStreamTranscription_contextClose(t *testing.T) {
	b, err := base64.StdEncoding.DecodeString(
		`UklGRjzxPQBXQVZFZm10IBAAAAABAAEAgD4AAAB9AAACABAAZGF0YVTwPQAAAAAAAAAAAAAAAAD//wIA/f8EAA==`,
	)
	if err != nil {
		t.Fatalf("expect decode audio bytes, %v", err)
	}
	audio := bytes.NewReader(b)

	cfg, err := integrationtest.LoadConfigWithDefaultRegion("us-west-2")
	if err != nil {
		t.Fatalf("failed to load config, %v", err)
	}

	client := transcribestreaming.NewFromConfig(cfg)
	resp, err := client.StartStreamTranscription(context.Background(), &transcribestreaming.StartStreamTranscriptionInput{
		LanguageCode:         types.LanguageCodeEnUs,
		MediaEncoding:        types.MediaEncodingPcm,
		MediaSampleRateHertz: aws.Int32(16000),
	})
	if err != nil {
		t.Fatalf("failed to start streaming, %v", err)
	}

	stream := resp.GetStream()
	defer stream.Close()

	ctx, cancelFn := context.WithCancel(context.Background())
	defer cancelFn()

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		// Deferred so wg.Wait below is released on every exit path.
		defer wg.Done()

		err := streamAudioFromReader(ctx, stream.Writer, audioFrameSize, audio)
		if err == nil {
			t.Errorf("expect error")
			// Nothing further to inspect; calling err.Error() would panic.
			return
		}
		if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
			t.Errorf("expect %q error in %q", e, a)
		}
	}()

	cancelFn()

Loop:
	for {
		select {
		case <-ctx.Done():
			break Loop
		case event, ok := <-stream.Events():
			if !ok {
				break Loop
			}
			switch e := event.(type) {
			case *types.TranscriptResultStreamMemberTranscriptEvent:
				t.Logf("got event, %v results", len(e.Value.Transcript.Results))
				for _, res := range e.Value.Transcript.Results {
					for _, alt := range res.Alternatives {
						t.Logf("* %s", aws.ToString(alt.Transcript))
					}
				}
			default:
				t.Fatalf("unexpected event, %T", event)
			}
		}
	}

	wg.Wait()

	if err := stream.Err(); err != nil {
		t.Fatalf("expect no error from stream, got %v", err)
	}
}

// streamAudioFromReader reads input in chunks of at most frameSize bytes and
// sends each non-empty chunk to stream as an audio event, stopping when input
// reports io.EOF.
//
// The stream is closed before returning on every path. A close failure is
// reported only when no earlier error occurred. frameSize must be positive;
// otherwise an error is returned without reading from input.
//
// The returned error wraps the underlying read, send, or close error.
func streamAudioFromReader(ctx context.Context, stream transcribestreaming.AudioStreamWriter, frameSize int, input io.Reader) (err error) {
	defer func() {
		if closeErr := stream.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("failed to close stream, %w", closeErr)
		}
	}()

	// A negative size would panic in make, and a zero-length buffer lets Read
	// return (0, nil) indefinitely, so reject both up front.
	if frameSize <= 0 {
		return fmt.Errorf("invalid audio frame size %d, must be positive", frameSize)
	}

	frame := make([]byte, frameSize)
	for {
		// Keep the read error separate from the send error: a Read may return
		// data together with an error, and that error must still be handled
		// after the data has been sent.
		n, readErr := input.Read(frame)
		if n > 0 {
			sendErr := stream.Send(ctx, &types.AudioStreamMemberAudioEvent{Value: types.AudioEvent{
				AudioChunk: frame[:n],
			}})
			if sendErr != nil {
				return fmt.Errorf("failed to send audio event, %w", sendErr)
			}
		}

		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return fmt.Errorf("failed to read audio, %w", readErr)
		}
	}
}