File: stream.rs

package info (click to toggle)
aws-crt-python 0.20.4%2Bdfsg-1~bpo12%2B1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm-backports
  • size: 72,656 kB
  • sloc: ansic: 381,805; python: 23,008; makefile: 6,251; sh: 4,536; cpp: 699; ruby: 208; java: 77; perl: 73; javascript: 46; xml: 11
file content (134 lines) | stat: -rw-r--r-- 3,662 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{
    io,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};
use tokio::{
    io::{AsyncRead, AsyncWrite, ReadBuf},
    net::TcpStream,
};

/// Boxed callback with the same shape as a `poll_read` call: it receives the pinned inner
/// `TcpStream`, the task context and the destination buffer, and returns the poll result.
type ReadFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &mut ReadBuf) -> Poll<io::Result<()>>>;
/// Boxed callback with the same shape as a `poll_write` call: it receives the pinned inner
/// `TcpStream`, the task context and the bytes to write, and returns the poll result.
type WriteFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &[u8]) -> Poll<io::Result<usize>>>;
/// Boxed callback with the same shape as a `poll_shutdown` call: it receives the pinned inner
/// `TcpStream` and the task context, and returns the poll result.
type ShutdownFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context) -> Poll<io::Result<()>>>;

/// Pending per-operation callbacks for a `TestStream`.
///
/// Each slot holds at most one callback. The derived `Default` leaves every slot `None`.
#[derive(Default)]
struct OverrideMethods {
    /// Callback to run instead of the next `poll_read` on the inner stream, if any.
    next_read: Option<ReadFn>,
    /// Callback to run instead of the next `poll_write` on the inner stream, if any.
    next_write: Option<WriteFn>,
    /// Callback to run instead of the next `poll_shutdown` on the inner stream, if any.
    next_shutdown: Option<ShutdownFn>,
}

/// Handle used to install callbacks that replace the next read, write or shutdown poll of a
/// `TestStream`.
///
/// The callbacks live behind a `Mutex`, so they can be set through a shared reference.
#[derive(Default)]
pub struct Overrides(Mutex<OverrideMethods>);

impl Overrides {
    pub fn next_read(&self, input: Option<ReadFn>) {
        if let Ok(mut overrides) = self.0.lock() {
            overrides.next_read = input;
        }
    }

    pub fn next_write(&self, input: Option<WriteFn>) {
        if let Ok(mut overrides) = self.0.lock() {
            overrides.next_write = input;
        }
    }

    pub fn next_shutdown(&self, input: Option<ShutdownFn>) {
        if let Ok(mut overrides) = self.0.lock() {
            overrides.next_shutdown = input;
        }
    }

    pub fn is_consumed(&self) -> bool {
        if let Ok(overrides) = self.0.lock() {
            overrides.next_read.is_none()
                && overrides.next_write.is_none()
                && overrides.next_shutdown.is_none()
        } else {
            false
        }
    }
}

// SAFETY / NOTE(review): these impls are written by hand because the boxed `Fn` callbacks in
// `ReadFn`/`WriteFn`/`ShutdownFn` carry no `Send` bound, so the compiler does not derive
// `Send`/`Sync` for `Overrides` on its own. Within this file the callbacks are only reached
// through the inner `Mutex`, but soundness presumably still relies on callers installing only
// callbacks that are safe to move to, call on, and drop on another thread — TODO confirm
// against the call sites.
unsafe impl Send for Overrides {}
unsafe impl Sync for Overrides {}

/// A `TcpStream` wrapper whose next read, write or shutdown poll can be replaced by a callback
/// installed through the shared `Overrides` handle.
pub struct TestStream {
    /// The wrapped stream that I/O is forwarded to when no callback is pending.
    stream: TcpStream,
    /// Shared table of pending callbacks; also handed out by `overrides()`.
    overrides: Arc<Overrides>,
}

impl TestStream {
    pub fn new(stream: TcpStream) -> Self {
        let overrides = Arc::new(Overrides::default());
        Self { stream, overrides }
    }

    pub fn overrides(&self) -> Arc<Overrides> {
        self.overrides.clone()
    }
}

impl AsyncRead for TestStream {
    /// Polls a read, routing it through the pending read override when one is installed.
    ///
    /// The override is removed from the table before it runs, so it applies to a single call.
    /// Without an override (or if the lock is poisoned) the call goes to the inner stream.
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // The guard only lives for this statement, so the lock is released before the
        // callback (or the inner stream) is polled.
        let hook = this
            .overrides
            .0
            .lock()
            .ok()
            .and_then(|mut methods| methods.next_read.take());
        let inner = Pin::new(&mut this.stream);
        match hook {
            Some(callback) => callback(inner, ctx, buf),
            None => inner.poll_read(ctx, buf),
        }
    }
}

impl AsyncWrite for TestStream {
    /// Polls a write, routing it through the pending write override when one is installed.
    ///
    /// The override is removed from the table before it runs, so it applies to a single call.
    /// Without an override (or if the lock is poisoned) the call goes to the inner stream.
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // The guard only lives for this statement, so the lock is released before the
        // callback (or the inner stream) is polled.
        let hook = this
            .overrides
            .0
            .lock()
            .ok()
            .and_then(|mut methods| methods.next_write.take());
        let inner = Pin::new(&mut this.stream);
        match hook {
            Some(callback) => callback(inner, ctx, buf),
            None => inner.poll_write(ctx, buf),
        }
    }

    /// Flushes the inner stream directly; there is no override hook for flushing.
    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_flush(ctx)
    }

    /// Polls a shutdown, routing it through the pending shutdown override when one is installed.
    ///
    /// The override is removed from the table before it runs, so it applies to a single call.
    /// Without an override (or if the lock is poisoned) the call goes to the inner stream.
    fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let hook = this
            .overrides
            .0
            .lock()
            .ok()
            .and_then(|mut methods| methods.next_shutdown.take());
        let inner = Pin::new(&mut this.stream);
        match hook {
            Some(callback) => callback(inner, ctx),
            None => inner.poll_shutdown(ctx),
        }
    }
}