File: cache_driver.py

package info (click to toggle)
geneagrapher-core 0.1.4-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 508 kB
  • sloc: python: 941; makefile: 37
file content (99 lines) | stat: -rw-r--r-- 2,804 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""This example demonstrates a simple cache using Redis.

In order to run this, you will need to install the redis package into
your Python environment and run a Redis instance on your development
machine. You can accomplish this by doing
`poetry install --with examples`.

A couple of notes:

1. This example does not specify any TTLs for data that is cached. A
   more complete implementation probably should do that.
2. This example does not take any command-line arguments to control
   the graph that is being built. It simply hardcodes a starting
   record ID. A general driver would want to accept the starting
   record IDs as input.

Running:
```
$ poetry run python cache_driver.py

# If you want to see nicer output and have jq installed.
$ poetry run python cache_driver.py | jq

# If you want to see the progress bar either redirect output to a file
# or /dev/null.
$ poetry run python cache_driver.py > /dev/null
```

"""

from geneagrapher_core.record import CacheResult, Record, RecordId
from geneagrapher_core.traverse import TraverseDirection, TraverseItem, build_graph

import asyncio
import json
import redis.asyncio as redis
import sys
from typing import Optional, Tuple


class RedisCache:
    """A minimal geneagrapher-core record cache backed by a local Redis
    instance.

    Records are stored as JSON under namespaced keys. A record that was
    looked up upstream but found not to exist is cached as an empty JSON
    object ({}) so repeated lookups of nonexistent records are also served
    from the cache.
    """

    def __init__(self) -> None:
        # Connect to a Redis server on the default local port.
        self.r = redis.Redis(host="localhost", port=6379, db=0)

    def key(self, id: RecordId) -> str:
        """Namespace the record ID so this cache does not collide with
        other users of the same Redis database."""
        return f"ggrapher::{id}"

    async def get(self, id: RecordId) -> Tuple[CacheResult, Optional[Record]]:
        """Look up a record in the cache.

        Returns (MISS, None) when the key is absent, (HIT, None) for a
        cached null record, and (HIT, record) otherwise.
        """
        val = await self.r.get(self.key(id))

        if val is None:
            # Key absent entirely: a cache miss.
            return (CacheResult.MISS, None)

        # BUG FIX: redis-py returns the stored value as bytes (e.g.
        # b"{}"), so the original comparison `val == {}` was always False
        # and null-value hits were returned as empty records. Decode the
        # JSON first, then compare.
        record = json.loads(val)
        if record == {}:
            # A null-value hit: the record was cached as nonexistent.
            return (CacheResult.HIT, None)

        # General hit.
        return (CacheResult.HIT, record)

    async def set(self, id: RecordId, value: Optional[Record]) -> None:
        """Cache a record; None (record does not exist) is stored as {}."""
        val = {} if value is None else value
        await self.r.set(self.key(id), json.dumps(val))


def display_progress(queued: int, doing: int, done: int) -> None:
    """Render a single-line progress bar to stderr.

    The three arguments are task counts: completed tasks are drawn as
    filled blocks, in-flight tasks as colons, and the remainder as dots.
    The carriage-return line ending redraws the bar in place.
    """
    prefix = "Progress: "
    size = 60
    count = queued + doing + done

    if count == 0:
        # Nothing to report yet; avoid dividing by zero below.
        return

    x = int(size * done / count)
    y = int(size * doing / count)

    print(
        f"{prefix}[{'█' * x}{':' * y}{'.' * (size - x - y)}] {done}/{count}",
        end="\r",
        file=sys.stderr,
        flush=True,
    )


async def get_progress(
    tg: asyncio.TaskGroup, to_fetch: int, fetching: int, fetched: int
) -> None:
    """Progress-report callback passed to build_graph.

    The TaskGroup argument is part of the callback signature but unused
    here; the three counts are forwarded to the stderr progress bar.
    """
    display_progress(to_fetch, fetching, fetched)


if __name__ == "__main__":
    # Build the advisor graph for a hardcoded starting record, using a
    # local Redis instance as the record cache.
    start_items = [TraverseItem(RecordId(18231), TraverseDirection.ADVISORS)]
    ggraph = asyncio.run(
        build_graph(
            start_items,
            cache=RedisCache(),
            report_callback=get_progress,
        )
    )

    print(file=sys.stderr)  # this adds a newline to the progress bar
    print(json.dumps(ggraph))