1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
|
#------------------------------------------------------------------------------
# Copyright (c) 2018-2024, Nucleic Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
#------------------------------------------------------------------------------
import sys
import ast
from textwrap import dedent
import pytest
from utils import compile_source
from enaml.core.parser import parse
from .test_parser import validate_ast
# Module prelude shared by every parametrised snippet.  The dedented snippet
# is substituted for ``{}``, so the resulting module has three top-level
# statements: the ``asyncio`` import, the ``fetch`` helper and the snippet.
FUNC_TEMPLATE = (
    "\n"
    "import asyncio\n"
    "async def fetch(query):\n"
    "    return query\n"
    "{}\n"
)
# Snippets exercising async/await syntax, keyed by a short description used
# as the pytest parameter id.  Each value is dedented and dropped into
# FUNC_TEMPLATE, so names like ``fetch`` and ``asyncio`` come from that
# prelude.  The snippets are only parsed, never executed, but they are kept
# free of undefined names so they read as sensible Python.
TEST_SOURCE = {
    'await': """
    async def function(queries):
        r = await fetch(queries)
        return r
    """,
    'await if': """
    async def function(query):
        result = await fetch(query)
        if not query:
            return
        return result
    """,
    'if await': """
    async def function(query):
        if not query:
            return
        result = await fetch(query)
        return result
    """,
    # The next three deliberately use ``await`` as a bare expression
    # statement (not an assignment) on a name, a subscript and an attribute.
    'await future': """
    async def function(query):
        f = fetch(query)
        await f
        return query
    """,
    'await subscript': """
    async def function(query):
        tasks = [fetch(query)]
        await tasks[0]
        return query
    """,
    'await attr': """
    async def function(query):
        class API:
            search = fetch
        api = API()
        await api.search(query)
        return query
    """,
    'await list comp': """
    async def function(queries):
        result = [await fetch(q) for q in queries]
        return result
    """,
    'await dict comp': """
    async def function(queries):
        result = {i:await fetch(q) for i, q in enumerate(queries)}
        return result
    """,
    'async for': """
    async def function(queries):
        results = []
        async for r in fetch(queries):
            results.append(r)
        return results
    """,
    'async for or': """
    async def function(queries):
        results = []
        async for r in queries or fetch(queries):
            results.append(r)
        return results
    """,
    'async for comp': """
    async def function(queries):
        results = []
        async for r in [f for f in fetch(queries)]:
            results.append(r)
        return results
    """,
    'async for or comp': """
    async def function(queries):
        results = []
        async for r in queries or [f for f in fetch(queries)]:
            results.append(r)
        return results
    """,
    'async with': """
    async def function(query):
        lock = asyncio.Lock()
        async with lock:
            result = await fetch(query)
        return result
    """,
}
if sys.version_info >= (3, 6):
    # Asynchronous comprehensions (``async for`` inside a comprehension)
    # were introduced in Python 3.6 (PEP 530).
    TEST_SOURCE.update({
        'async for list comp': """
        async def function(queries):
            result = [r async for r in fetch(queries)]
            return result
        """,
        'async for if list comp': """
        async def function(queries):
            result = [r async for r in fetch(queries) if r]
            return result
        """,
    })
if sys.version_info < (3, 7):
    # ``async`` and ``await`` only became reserved keywords in Python 3.7.
    # On older interpreters they are ordinary identifiers, so using them as
    # variable names inside a plain ``def`` must still parse.
    TEST_SOURCE['async not keyword'] = """
    def function(queries):
        async = False
        return queries
    """
    TEST_SOURCE['await not keyword'] = """
    def function(queries):
        await = False
        return queries
    """
@pytest.mark.parametrize('desc', TEST_SOURCE.keys())
def test_async(desc):
    """Check that enaml parses each TEST_SOURCE snippet exactly like Python.

    The snippet selected by ``desc`` is dedented and substituted into
    FUNC_TEMPLATE. The resulting module is parsed with :func:`ast.parse` and
    with enaml's ``parse``. Every top-level statement must then produce an
    equivalent AST: the ``asyncio`` import, the ``fetch`` helper and the
    snippet function.
    """
    src = FUNC_TEMPLATE.format(dedent(TEST_SOURCE[desc]))
    # Parsing with the stdlib first ensures the snippet is valid Python.
    py_ast = ast.parse(src)
    # The first node of enaml's module exposes the parsed Python statements
    # through its ``.ast`` attribute.
    enaml_ast = parse(src).body[0].ast
    # Compare the statement counts before zipping, so a dropped or extra
    # statement is reported rather than silently truncated by zip().
    assert len(enaml_ast.body) == len(py_ast.body)
    for py_node, enaml_node in zip(py_ast.body, enaml_ast.body):
        validate_ast(py_node, enaml_node, True)
def test_decl_async_func():
    """Compare enaml async declarative functions with Python ``async def``.

    Two enaml forms are checked against the ``@d_func``-decorated
    ``async def search`` of an equivalent Python class:

    - the ``async func search`` declaration on ``MainWindow``;
    - the ``async search => (query):`` override on ``CustomWindow``.

    The enaml source is then compiled to make sure it is usable, not just
    parseable.
    """
    py_src = dedent("""
    from enaml.core.declarative import d_func
    from enaml.widgets.api import Window, Label
    async def fetch(query):
        return query
    class MainWindow(Window):
        @d_func
        async def search(self, query):
            result = await fetch(query)
            return result
    """)
    enaml_src = dedent("""
    from enaml.core.declarative import d_func
    from enaml.widgets.api import Window, Label
    async def fetch(query):
        return query
    enamldef MainWindow(Window):
        async func search(query):
            result = await fetch(query)
            return result
    enamldef CustomWindow(MainWindow):
        async search => (query):
            result = await fetch(query)
            return result
    """)
    # Python module layout: two imports, ``fetch``, then ``MainWindow``.
    # The reference AST is the first statement in that class body.
    reference = ast.parse(py_src).body[3].body[0]
    enaml_module = parse(enaml_src)
    # enaml body[1] is ``enamldef MainWindow`` (declaration syntax) and
    # body[2] is ``enamldef CustomWindow`` (override syntax).
    for index in (1, 2):
        declared = enaml_module.body[index].body[0]
        validate_ast(reference, declared.funcdef, True)
    # Make sure it compiles.
    compile_source(enaml_src, 'CustomWindow')
|