From: Taku Kudo <taku@google.com>
Date: Fri, 5 Aug 2022 14:47:02 +0900
Subject: Automatically detect the number of CPUs in batch processing.

Signed-off-by: Kentaro Hayashi <kenhys@gmail.com>
---
 python/src/sentencepiece/__init__.py            | 27 +++++++++++++--------
 python/src/sentencepiece/sentencepiece.i        | 32 ++++++++++++++++---------
 python/src/sentencepiece/sentencepiece_wrap.cxx |  5 +++-
 python/test/sentencepiece_test.py               | 32 +++++++++++++++++++++++++
 4 files changed, 74 insertions(+), 22 deletions(-)

diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py
index 12dc631..ce9d60d 100644
--- a/python/src/sentencepiece/__init__.py
+++ b/python/src/sentencepiece/__init__.py
@@ -97,6 +97,13 @@ class ImmutableSentencePieceText_ImmutableSentencePiece(object):
               'begin: {}\n'
               'end: {}\n').format(self.piece, self.id, self.surface,
                                   self.begin, self.end)
+
+    def __eq__(self, other):
+      return self.piece == other.piece and self.id == other.id and self.surface == other.surface and self.begin == other.begin and self.end == other.end
+
+    def __hash__(self):
+      return hash(str(self))
+
     __repr__ = __str__
 
 
@@ -395,7 +402,7 @@ class SentencePieceProcessor(object):
              enable_sampling=False,
              nbest_size=-1,
              alpha=0.1,
-             num_threads=1):
+             num_threads=-1):
       """Initialzie sentencepieceProcessor.
 
       Args:
@@ -407,15 +414,15 @@ class SentencePieceProcessor(object):
           reversing (if enabled).
         reverse: Reverses the tokenized sequence (Default = false)
         emit_unk_piece: Emits the unk literal string (Default = false)
-        nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+        nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
                     nbest_size = {0,1}: No sampling is performed.
                     nbest_size > 1: samples from the nbest_size results.
                     nbest_size < 0: assuming that nbest_size is infinite and samples
                       from the all hypothesis (lattice) using
                       forward-filtering-and-backward-sampling algorithm.
         alpha: Soothing parameter for unigram sampling, and dropout probability of
-          merge operations for BPE-dropout.
-        num_threads: number of threads in batch processing.
+               merge operations for BPE-dropout.
+        num_threads: number of threads in batch processing (Default = -1, auto-detected).
       """
 
       _sentencepiece_processor_init_native(self)
@@ -450,18 +457,18 @@ class SentencePieceProcessor(object):
         out_type: output type. int or str.
         add_bos: Add <s> to the result (Default = false)
         add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
-          reversing (if enabled).
+                 reversing (if enabled).
         reverse: Reverses the tokenized sequence (Default = false)
         emit_unk_piece: Emits the unk literal string (Default = false)
-        nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+        nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
                     nbest_size = {0,1}: No sampling is performed.
                     nbest_size > 1: samples from the nbest_size results.
                     nbest_size < 0: assuming that nbest_size is infinite and samples
-                      from the all hypothesis (lattice) using
-                      forward-filtering-and-backward-sampling algorithm.
+                    from the all hypothesis (lattice) using
+                    forward-filtering-and-backward-sampling algorithm.
         alpha: Soothing parameter for unigram sampling, and merge probability for
                BPE-dropout (probablity 'p' in BPE-dropout paper).
-        num_threads: the number of threads used in the batch processin (Default = 1).
+        num_threads: the number of threads used in the batch processing (Default = -1).
       """
 
       if out_type is None:
@@ -722,7 +729,7 @@ class SentencePieceProcessor(object):
 
       Args:
         out_type: output type. str or 'serialized_proto' or 'immutable_proto' (Default = str)
-        num_threads: the number of threads used in the batch processin (Default = 1).
+        num_threads: the number of threads used in the batch processing (Default = -1).
       """
 
       if num_threads is None:
diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i
index 8309fc2..e22f763 100644
--- a/python/src/sentencepiece/sentencepiece.i
+++ b/python/src/sentencepiece/sentencepiece.i
@@ -233,9 +233,12 @@ class ThreadPool {
 
 template <typename T>
 inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
+  if (*num_threads < 0) {
+    *num_threads = std::thread::hardware_concurrency();
+  }
   *num_threads = std::max<int>(1,
                                std::min<int>({*num_threads,
-                                   static_cast<int>(ins.size()), 256}));
+                                     static_cast<int>(ins.size()), 256}));
 }
 
 #define DEFINE_ENCODE_BATCH_FUNC_IMPL(FuncName, InType, OutType)        \
@@ -675,7 +678,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
            enable_sampling=False,
            nbest_size=-1,
            alpha=0.1,
-           num_threads=1):
+           num_threads=-1):
     """Initialzie sentencepieceProcessor.
 
     Args:
@@ -687,15 +690,15 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
         reversing (if enabled).
       reverse: Reverses the tokenized sequence (Default = false)
       emit_unk_piece: Emits the unk literal string (Default = false)
-      nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+      nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
                   nbest_size = {0,1}: No sampling is performed.
                   nbest_size > 1: samples from the nbest_size results.
                   nbest_size < 0: assuming that nbest_size is infinite and samples
                     from the all hypothesis (lattice) using
                     forward-filtering-and-backward-sampling algorithm.
       alpha: Soothing parameter for unigram sampling, and dropout probability of
-        merge operations for BPE-dropout.
-      num_threads: number of threads in batch processing.
+             merge operations for BPE-dropout.
+      num_threads: number of threads in batch processing (Default = -1, auto-detected).
     """
 
     _sentencepiece_processor_init_native(self)
@@ -730,18 +733,18 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
       out_type: output type. int or str.
       add_bos: Add <s> to the result (Default = false)
       add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
-        reversing (if enabled).
+               reversing (if enabled).
       reverse: Reverses the tokenized sequence (Default = false)
       emit_unk_piece: Emits the unk literal string (Default = false)
-      nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+      nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
                   nbest_size = {0,1}: No sampling is performed.
                   nbest_size > 1: samples from the nbest_size results.
                   nbest_size < 0: assuming that nbest_size is infinite and samples
-                    from the all hypothesis (lattice) using
-                    forward-filtering-and-backward-sampling algorithm.
+                  from the all hypothesis (lattice) using
+                  forward-filtering-and-backward-sampling algorithm.
       alpha: Soothing parameter for unigram sampling, and merge probability for
              BPE-dropout (probablity 'p' in BPE-dropout paper).
-      num_threads: the number of threads used in the batch processin (Default = 1).
+      num_threads: the number of threads used in the batch processing (Default = -1).
     """
 
     if out_type is None:
@@ -1002,7 +1005,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
 
     Args:
       out_type: output type. str or 'serialized_proto' or 'immutable_proto' (Default = str)
-      num_threads: the number of threads used in the batch processin (Default = 1).
+      num_threads: the number of threads used in the batch processing (Default = -1).
     """
 
     if num_threads is None:
@@ -1260,6 +1263,13 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
               'begin: {}\n'
               'end: {}\n').format(self.piece, self.id, self.surface,
                                   self.begin, self.end)
+
+    def __eq__(self, other):
+      return self.piece == other.piece and self.id == other.id and self.surface == other.surface and self.begin == other.begin and self.end == other.end
+
+    def __hash__(self):
+      return hash(str(self))
+
     __repr__ = __str__
   %}
 }
diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx
index 0a8df5f..1eac211 100644
--- a/python/src/sentencepiece/sentencepiece_wrap.cxx
+++ b/python/src/sentencepiece/sentencepiece_wrap.cxx
@@ -3042,9 +3042,12 @@ class ThreadPool {
 
 template <typename T>
 inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
+  if (*num_threads < 0) {
+    *num_threads = std::thread::hardware_concurrency();
+  }
   *num_threads = std::max<int>(1,
                                std::min<int>({*num_threads,
-                                   static_cast<int>(ins.size()), 256}));
+                                     static_cast<int>(ins.size()), 256}));
 }
 
 #define DEFINE_ENCODE_BATCH_FUNC_IMPL(FuncName, InType, OutType)        \
diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py
index ed792bd..6cbe077 100755
--- a/python/test/sentencepiece_test.py
+++ b/python/test/sentencepiece_test.py
@@ -332,6 +332,29 @@ class TestSentencepieceProcessor(unittest.TestCase):
     self.assertEqual(s4, y4)
     self.assertEqual(s5, y5)
 
+    hset_piece = defaultdict(int)
+
+    # eq test
+    for i in range(len(s1.pieces)):
+      self.assertEqual(s1.pieces[i], t1.pieces[i])
+      hset_piece[s1.pieces[i]] += 1
+      hset_piece[t1.pieces[i]] += 1
+
+    self.assertEqual(len(hset_piece), len(s1.pieces))
+
+    # hash test
+    hset = defaultdict(int)
+    hset[s1] += 1
+    hset[t1] += 1
+    hset[s3] += 1
+    hset[t3] += 1
+
+    self.assertEqual(len(hset), 2)
+    self.assertEqual(hset[s1], 2)
+    self.assertEqual(hset[s3], 2)
+    self.assertEqual(hset[t1], 2)
+    self.assertEqual(hset[t3], 2)
+
     x1 = self.sp_.encode_as_serialized_proto(text)
     x2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
     x3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
@@ -363,6 +386,15 @@ class TestSentencepieceProcessor(unittest.TestCase):
       pieces.append(s1.pieces[i].piece)
     self.assertEqual(pieces, v2)
 
+    for v in s3.nbests:
+      self.assertEqual(text, v.text)
+      self.assertEqual(self.sp_.Decode([x.id for x in v.pieces]), text)
+
+    for i in range(len(s3.nbests)):
+      self.assertEqual(text, s3.nbests[i].text)
+      self.assertEqual(
+          self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text)
+
     # Japanese offset
     s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123')
     surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces]
