import sys
from rdkit import Chem
from rdkit import six
import threading
import multiprocessing

# This just exercises RDKit from many Python threads to make sure nothing
# crashes while RDKit releases the GIL; the SMARTS matching used is recursive.
# Reference molecule, parsed from an embedded V2000 mol block.
ref_sdf = '\n     RDKit          3D\n\n 22 23  0  0  1  0  0  0  0  0999 V2000\n   -6.1917   -1.9517    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -5.0664   -1.3009    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -3.9401   -1.9499    0.0000 O   0  0  0  0  0  0  0  0  0  0  0  0\n   -2.8148   -1.2991    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -1.6885   -1.9483    0.0000 N   0  0  0  0  0  0  0  0  0  0  0  0\n   -0.5632   -1.2973    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -0.5642    0.0027    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -1.6905    0.6517    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -1.6916    1.9517    0.0000 N   0  0  0  0  0  0  0  0  0  0  0  0\n   -2.8158    0.0009    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -3.9422    0.6501    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -5.0685    1.2991    0.0000 N   0  0  0  0  0  0  0  0  0  0  0  0\n    0.5632   -1.9465    0.0000 N   0  0  0  0  0  0  0  0  0  0  0  0\n    1.6885   -1.2955    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    1.6874    0.0046    0.0000 O   0  0  0  0  0  0  0  0  0  0  0  0\n    2.8148   -1.9447    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    3.9401   -1.2936    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    3.9391    0.0064    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    5.0644    0.6572    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    6.1907    0.0082    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    6.1917   -1.2918    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    5.0664   -1.9429    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n  1  2  1  0\n  2  3  1  0\n  3  4  1  0\n  4 10  2  0\n  4  5  1  0\n  5  6  2  0\n  6  7  1  0\n  6 13  1  0\n  7  8  2  0\n  8  9  1  0\n  8 10  1  0\n 10 11  1  0\n 11 12  3  0\n 13 14  1  0\n 14 15  2  0\n 14 16  1  0\n 16 17  1  0\n 17 22  1  0\n 17 18  2  0\n 18 19  1  0\n 19 20  2  0\n 20 21  1  0\n 21 22  2  0\nM  END'
ref_mol = Chem.MolFromMolBlock(ref_sdf)
if ref_mol is None:
  raise ValueError('Bad ref structure')

# Query pattern that runner() matches against ref_mol.
core_smarts = '[#6]-!@[#6]-!@[#8]-!@[#6]:1:[#6](-!@[#6]#!@[#7]):[#6](-!@[#7]):[#6]:[#6](-!@[#7]-!@[#6](-!@[#6]-!@[#6]:2:[#6]:[#6]:[#6]:[#6]:[#6]:2)=!@[#8]):[#7]:1'
core_mol = Chem.MolFromSmarts(core_smarts)
if core_mol is None:
  raise ValueError('Bad core structure')

# Method name -> result of the single-threaded call; runner() compares
# every later call of that method against the stored value.
expected = {}


def runner(func, args):
  """Call method ``func`` on ``ref_mol`` and check it against ``expected``.

  ``args`` is passed as the single positional argument; pass ``None`` to call
  the method with no arguments.  If ``expected`` already holds a result for
  ``func``, the new result must compare equal to it (AssertionError
  otherwise).  Returns the result of the call.
  """
  method = getattr(ref_mol, func)
  # Compare with None explicitly: a valid argument may still be falsy.
  res = method() if args is None else method(args)
  if func in expected:
    assert res == expected[func], "Got %r expected %r" % (res, expected[func])
  return res


funcs = ["GetSubstructMatch", "GetSubstructMatches", "HasSubstructMatch"]

# get the expected results from the non-thread version
for func in funcs:
  expected[func] = runner(func, core_mol)

# An exception raised inside a threading.Thread target is only printed by the
# thread machinery and never reaches the main thread, so a mismatch would not
# fail the script.  Record failures here and check them after joining.
thread_failures = []


def _checked_runner(func, args):
  """Run ``runner(func, args)``, recording any exception in ``thread_failures``."""
  try:
    runner(func, args)
  except Exception as exc:
    thread_failures.append((func, exc))


# Each iteration starts one thread per entry in funcs plus one ToBinary
# thread, so the total comes to 100 threads per cpu.
threads_per_iteration = len(funcs) + 1
nthreads = multiprocessing.cpu_count() * 100 // threads_per_iteration
threads = []
for _ in range(nthreads):
  for func in funcs:
    t = threading.Thread(target=_checked_runner, args=(func, core_mol))
    t.start()
    threads.append(t)
  t = threading.Thread(target=_checked_runner, args=("ToBinary", None))
  t.start()
  threads.append(t)
for t in threads:
  t.join()
assert not thread_failures, "threaded calls failed: %r" % (thread_failures,)


def LogError():
  """Send ten numbered messages (1..10) through ``Chem.LogErrorMsg``."""
  for i in range(1, 11):
    Chem.LogErrorMsg(str(i) + ":: My dog has fleas")


def LogWarning():
  """Send ten numbered messages (1..10) through ``Chem.LogWarningMsg``."""
  for i in range(1, 11):
    Chem.LogWarningMsg(str(i) + ":: All good boys to fine")

# The block below is disabled: run before Chem.WrapLogs(), it spews a lot of
# logging output from all threads, intermingled together.


# Never runs; flip the condition to exercise threaded logging before
# Chem.WrapLogs() is called.
if False:
  nthreads = int(multiprocessing.cpu_count())
  threads = []
  for i in range(0, nthreads):
    # Even iterations log errors, odd ones warnings (one thread per func).
    worker = LogError if i % 2 == 0 else LogWarning
    for func in funcs:
      t = threading.Thread(target=worker)
      t.start()
      threads.append(t)
    # Plus one extra warning thread per iteration.
    t = threading.Thread(target=LogWarning)
    t.start()
    threads.append(t)

  for t in threads:
    t.join()

Chem.WrapLogs()

# After WrapLogs() the RDKit log messages are captured via sys.stderr below,
# so stderr is swapped for a StringIO.  It is always restored in a finally
# clause: otherwise a failure would print its traceback into the StringIO.
err = sys.stderr
stringio = sys.stderr = six.StringIO()
try:
  # now the errors should be synchronized...
  nthreads = multiprocessing.cpu_count()
  threads = []
  for i in range(nthreads):
    # Even iterations log errors, odd ones warnings (one thread per func).
    worker = LogError if i % 2 == 0 else LogWarning
    for func in funcs:
      t = threading.Thread(target=worker)
      t.start()
      threads.append(t)
    # Plus one extra warning thread per iteration.
    t = threading.Thread(target=LogWarning)
    t.start()
    threads.append(t)

  for t in threads:
    t.join()
finally:
  sys.stderr = err

# Single-threaded check: each helper logs exactly ten messages.
stringio = sys.stderr = six.StringIO()
try:
  LogWarning()
  LogError()
finally:
  sys.stderr = err
output = stringio.getvalue()
assert output.count("WARNING") == 10, output
assert output.count("ERROR") == 10, output
