1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
|
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift Async Algorithms open source project
//
// Copyright (c) 2022 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
//
//===----------------------------------------------------------------------===//
import AsyncAlgorithms
/// A deterministic `Clock` whose time only moves when `advance()` or
/// `advance(by:)` is called.
///
/// Time is measured in whole integer steps starting at tick 0. Tasks that call
/// `sleep(until:tolerance:)` are resumed by `advance()` once the clock reaches
/// their deadline, or resumed with `CancellationError` if their task is
/// cancelled before the wakeup is delivered.
public struct ManualClock: Clock {
  /// A duration measured in whole clock steps.
  public struct Step: DurationProtocol {
    fileprivate var rawValue: Int

    fileprivate init(_ rawValue: Int) {
      self.rawValue = rawValue
    }

    public static func + (lhs: ManualClock.Step, rhs: ManualClock.Step) -> ManualClock.Step {
      .init(lhs.rawValue + rhs.rawValue)
    }

    public static func - (lhs: ManualClock.Step, rhs: ManualClock.Step) -> ManualClock.Step {
      .init(lhs.rawValue - rhs.rawValue)
    }

    public static func / (lhs: ManualClock.Step, rhs: Int) -> ManualClock.Step {
      .init(lhs.rawValue / rhs)
    }

    public static func * (lhs: ManualClock.Step, rhs: Int) -> ManualClock.Step {
      .init(lhs.rawValue * rhs)
    }

    public static func / (lhs: ManualClock.Step, rhs: ManualClock.Step) -> Double {
      Double(lhs.rawValue) / Double(rhs.rawValue)
    }

    public static func < (lhs: ManualClock.Step, rhs: ManualClock.Step) -> Bool {
      lhs.rawValue < rhs.rawValue
    }

    public static var zero: ManualClock.Step { .init(0) }

    /// Creates a duration of `amount` steps.
    public static func steps(_ amount: Int) -> Step {
      Step(amount)
    }
  }

  /// A point in time, counted in steps; a new clock starts at tick 0.
  public struct Instant: InstantProtocol, CustomStringConvertible {
    public typealias Duration = Step

    internal let rawValue: Int

    internal init(_ rawValue: Int) {
      self.rawValue = rawValue
    }

    public static func < (lhs: ManualClock.Instant, rhs: ManualClock.Instant) -> Bool {
      lhs.rawValue < rhs.rawValue
    }

    public func advanced(by duration: ManualClock.Step) -> ManualClock.Instant {
      .init(rawValue + duration.rawValue)
    }

    public func duration(to other: ManualClock.Instant) -> ManualClock.Step {
      .init(other.rawValue - rawValue)
    }

    public var description: String {
      "tick \(rawValue)"
    }
  }

  /// A suspended sleeper waiting for the clock to reach `deadline`.
  fileprivate struct Wakeup {
    /// Unique id of the `sleep` call that produced this wakeup.
    let generation: Int
    let continuation: UnsafeContinuation<Void, Error>
    let deadline: Instant
  }

  /// An entry in the scheduled set: either a pending wakeup, or a placeholder
  /// recording that a generation was cancelled before it was scheduled.
  ///
  /// Equality, hashing and ordering all use only the generation, so a
  /// `.cancelled(g)` value can be used to look up (and remove) a `.wakeup`
  /// with generation `g`, and vice versa.
  fileprivate enum Scheduled: Hashable, Comparable, CustomStringConvertible {
    case cancelled(Int)
    case wakeup(Wakeup)

    /// The generation identifying this entry, regardless of case.
    private var generation: Int {
      switch self {
      case .cancelled(let generation): return generation
      case .wakeup(let wakeup): return wakeup.generation
      }
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(generation)
    }

    var description: String {
      switch self {
      case .cancelled: return "Cancelled wakeup"
      case .wakeup(let wakeup): return "Wakeup at \(wakeup.deadline)"
      }
    }

    static func == (_ lhs: Scheduled, _ rhs: Scheduled) -> Bool {
      lhs.generation == rhs.generation
    }

    static func < (lhs: ManualClock.Scheduled, rhs: ManualClock.Scheduled) -> Bool {
      lhs.generation < rhs.generation
    }

    /// The wakeup deadline, or `nil` for a cancellation placeholder.
    var deadline: Instant? {
      switch self {
      case .cancelled: return nil
      case .wakeup(let wakeup): return wakeup.deadline
      }
    }

    /// Resumes the sleeper normally; a cancellation placeholder has nothing to resume.
    func resume() {
      switch self {
      case .wakeup(let wakeup):
        wakeup.continuation.resume()
      case .cancelled:
        break
      }
    }
  }

  /// Mutable clock state, only accessed inside `state.withCriticalRegion`.
  fileprivate struct State {
    /// Next id handed out by `sleep(until:tolerance:)`.
    var generation = 0
    var scheduled = Set<Scheduled>()
    var now = Instant(0)
    var hasSleepers = false
  }

  fileprivate let state = ManagedCriticalState(State())

  /// The current instant of the clock.
  public var now: Instant {
    state.withCriticalRegion { $0.now }
  }

  /// Always `.zero`: the clock is fully deterministic.
  public var minimumResolution: Step { .zero }

  /// Creates a clock at tick 0 with no sleepers.
  public init() { }

  /// Cancels the sleep identified by `generation`.
  ///
  /// If a wakeup for that generation is scheduled, it is removed and its
  /// continuation is resumed with `CancellationError`. Otherwise a
  /// `.cancelled` placeholder is recorded, so that a later `schedule` call
  /// for the same generation fails immediately instead of sleeping.
  fileprivate func cancel(_ generation: Int) {
    let continuation = state.withCriticalRegion { state -> UnsafeContinuation<Void, Error>? in
      // Lookup is by generation only (see `Scheduled.==`), so this also finds a `.wakeup`.
      guard let existing = state.scheduled.remove(.cancelled(generation)) else {
        // Nothing is scheduled for this generation (not scheduled yet, or already resumed):
        // leave a placeholder for `schedule` to find.
        state.scheduled.insert(.cancelled(generation))
        return nil
      }
      switch existing {
      case .wakeup(let wakeup):
        return wakeup.continuation
      case .cancelled:
        return nil
      }
    }
    // Resume outside the critical region.
    continuation?.resume(throwing: CancellationError())
  }

  /// Set to `true` when `sleep(until:tolerance:)` registers a wakeup with a future
  /// deadline. Reset to `false` whenever `advance()` resumes at least one sleeper,
  /// even if other sleepers are still scheduled.
  ///
  /// NOTE(review): this reads as "a sleeper registered since the last wakeup"
  /// rather than strictly "sleepers are pending". Confirm that callers rely on
  /// this before changing it.
  var hasSleepers: Bool {
    state.withCriticalRegion { $0.hasSleepers }
  }

  /// Moves the clock forward by exactly one step, then resumes every sleeper whose
  /// deadline has been reached.
  ///
  /// Sleepers are resumed outside the critical region, in ascending generation
  /// order (the order in which their `sleep` calls were made).
  public func advance() {
    let due = state.withCriticalRegion { state -> Set<Scheduled> in
      state.now = state.now.advanced(by: .steps(1))
      let now = state.now
      let due = state.scheduled.filter { item in
        guard let deadline = item.deadline else { return false }
        return deadline <= now
      }
      state.scheduled.subtract(due)
      if !due.isEmpty {
        state.hasSleepers = false
      }
      return due
    }
    for item in due.sorted() {
      item.resume()
    }
  }

  /// Advances the clock one step at a time, `steps` times, so that sleepers are
  /// resumed tick by tick.
  ///
  /// - Precondition: `steps` is not negative; a manual clock cannot move backwards.
  public func advance(by steps: Step) {
    precondition(steps >= .zero, "ManualClock cannot be advanced by a negative amount (\(steps.rawValue) steps)")
    for _ in 0..<steps.rawValue {
      advance()
    }
  }

  /// Registers `continuation` to be resumed when the clock reaches `deadline`.
  ///
  /// The continuation is resumed immediately in two cases:
  /// - with `CancellationError`, if `cancel` already left a placeholder for this generation;
  /// - normally, if `deadline` is not after the current instant.
  fileprivate func schedule(_ generation: Int, continuation: UnsafeContinuation<Void, Error>, deadline: Instant) {
    let resumption = state.withCriticalRegion { state -> (UnsafeContinuation<Void, Error>, Result<Void, Error>)? in
      let wakeup = Wakeup(generation: generation, continuation: continuation, deadline: deadline)
      if let existing = state.scheduled.remove(.wakeup(wakeup)) {
        switch existing {
        case .wakeup:
          fatalError("ManualClock: sleep generation \(generation) was scheduled more than once")
        case .cancelled:
          // Cancelled before it got here: don't add it back, just fail it.
          return (continuation, .failure(CancellationError()))
        }
      }
      guard deadline > state.now else {
        // The deadline is now or in the past, so run it immediately.
        return (continuation, .success(()))
      }
      // The deadline is in the future, so wait for `advance()` to reach it.
      state.hasSleepers = true
      state.scheduled.insert(.wakeup(wakeup))
      return nil
    }
    // Resume outside the critical region.
    if case let (pending, result)? = resumption {
      pending.resume(with: result)
    }
  }

  /// Suspends the calling task until the clock reaches `deadline`.
  ///
  /// - Parameters:
  ///   - deadline: The instant to wake at. If it is not after `now`, the call
  ///     returns without waiting (unless the task was already cancelled).
  ///   - tolerance: Ignored; `ManualClock` wakes a sleeper on the tick at which
  ///     its deadline is reached.
  /// - Throws: `CancellationError` if the task is cancelled before its wakeup is delivered.
  public func sleep(until deadline: Instant, tolerance: Step? = nil) async throws {
    let generation = state.withCriticalRegion { state -> Int in
      defer { state.generation += 1 }
      return state.generation
    }
    try await withTaskCancellationHandler {
      try await withUnsafeThrowingContinuation { continuation in
        schedule(generation, continuation: continuation, deadline: deadline)
      }
    } onCancel: {
      cancel(generation)
    }
  }
}
|