#
# Copyright (C) 2018 Neal Digre.
#
# This software may be modified and distributed under the terms
# of the MIT license. See the LICENSE file for details.
"""Language-specific and unicode utilities.
Todo:
* Variable UNK token in Vocabulary
"""
import abc
import os
def code2hex(unicode):
    """Return the integer code point for a unicode string.

    Args:
        unicode (str): Hex code point, optionally prefixed with
            ``U+`` (e.g. ``'U+3042'`` or ``'3042'``).

    Returns:
        int: Integer value of the code point.
    """
    # The original used ``lstrip('U+')``, which strips *characters*
    # ('U' and '+'), not a prefix.  It happened to be safe only because
    # no hex digit is 'U' or '+'; remove the prefix explicitly instead.
    if unicode.startswith('U+'):
        unicode = unicode[2:]
    return int(unicode, 16)
def code2char(unicode):
    """Return the single-character string for a unicode code point.

    Args:
        unicode (str): Hex code point string, optionally ``U+``-prefixed.

    Returns:
        str: The character at that code point.
    """
    codepoint = code2hex(unicode)
    return chr(codepoint)
class CharacterSet(object):
    """Character set abstract class.

    Subclasses implement :meth:`_unicode_ranges` to map a charset ID
    string to a list of inclusive ``(low, high)`` unicode ranges.
    """

    # NOTE(review): ``__metaclass__`` enforces abstractness on Python 2
    # only; Python 3 ignores this attribute, so the abstract method is
    # not enforced there — confirm the intended target version.
    __metaclass__ = abc.ABCMeta

    def __init__(self, charset):
        """Initializer

        Args:
            charset (str): ID for types of characters to include.
        """
        self._ranges = self._unicode_ranges(charset)

    @abc.abstractmethod
    def _unicode_ranges(self, charset):
        """Returns appropriate unicode ranges for specified ``charset``.

        Args:
            charset (str): ID of character set to use.

        Returns:
            :obj:`list` of :obj:`tuple`: Unicode ranges [(low, high)]
        """

    def in_charset(self, unicode):
        """Check if a character is in the defined character set.

        Args:
            unicode (str): String representation of unicode value.

        Returns:
            bool: True if the code point falls in any configured range.
        """
        hexcode = code2hex(unicode)
        # Generator (not a materialized list) lets any() short-circuit
        # on the first matching range; return the boolean directly
        # instead of the redundant if/else True/False.
        return any(low <= hexcode <= high for low, high in self._ranges)
class JapaneseUnicodes(CharacterSet):
    """Utility for accessing and manipulating Japanese character
    unicodes.

    Inherits from :obj:`CharacterSet`.

    Unicode ranges taken from [1] with edits for exceptions.

    References:
        [1] http://www.unicode.org/charts/
    """

    PUNCTUATION = [
        (int('25a0', 16), int('25ff', 16)),  # square
        (int('25b2', 16), int('25b3', 16)),  # triangle
        (int('25cb', 16), int('25cf', 16)),  # circle
        (int('25ef', 16), int('25ef', 16)),  # big circle
        (int('3200', 16), int('32ff', 16)),  # filled big circles
        (int('3000', 16), int('303f', 16)),  # CJK symbols, punctuation
        (int('3099', 16), int('309e', 16)),  # voicing, iteration marks
        (int('30a0', 16), int('30a0', 16)),  # double hyphen
        (int('30fb', 16), int('30fe', 16)),  # dot, prolonged, iteration
        (int('ff5b', 16), int('ff64', 16)),  # brackets, halfwidth punctuation
        (int('ffed', 16), int('ffee', 16))   # halfwidth square, circle
    ]
    HIRAGANA = [
        (int('3040', 16), int('3096', 16)),
        (int('309f', 16), int('309f', 16))   # より
    ]
    KATAKANA = [
        (int('30a1', 16), int('30fa', 16)),
        (int('30ff', 16), int('30ff', 16)),  # コト
        (int('ff65', 16), int('ff9d', 16))   # halfwidth
    ]
    KANA = HIRAGANA + KATAKANA
    MISC = [
        (int('0030', 16), int('0039', 16)),  # digits
        (int('ff00', 16), int('ff5a', 16)),  # roman characters
        (int('ffa0', 16), int('ffdc', 16)),  # hangul characters
        (int('ffe0', 16), int('ffec', 16)),  # symbols
        # (int('003f', 16), int('003f', 16)), # question mark
    ]
    # Kanji covers full CJK set and extensions
    KANJI = [
        (int('3400', 16), int('4db5', 16)),
        (int('4e00', 16), int('9fea', 16)),
        (int('f900', 16), int('fad9', 16)),
        (int('20000', 16), int('2ebe0', 16)),
    ]
    ALL = HIRAGANA + KATAKANA + KANJI + PUNCTUATION + MISC

    def __init__(self, charset):
        """Initializer.

        Args:
            charset (str): ID for types of characters to include; any
                string containing 'hiragana', 'katakana', 'kana',
                'kanji', 'punct' and/or 'misc', or exactly 'all'.
        """
        super(JapaneseUnicodes, self).__init__(charset)

    def _unicode_ranges(self, charset):
        """Resolve ``charset`` to a list of unicode ranges.

        Args:
            charset (str): ID of character set to use (see __init__).

        Returns:
            :obj:`list` of :obj:`tuple`: Unicode ranges [(low, high)].

        Raises:
            ValueError: If ``charset`` matches no known character set.
        """
        if charset == 'all':
            ranges = JapaneseUnicodes.ALL
        else:
            ranges = []
            # Deliberate elif chain: 'kana' is a substring of
            # 'katakana', so the ordering prevents a 'katakana'
            # charset from also matching the combined KANA group.
            if 'hiragana' in charset:
                ranges += JapaneseUnicodes.HIRAGANA
            elif 'katakana' in charset:
                ranges += JapaneseUnicodes.KATAKANA
            elif 'kana' in charset:
                ranges += JapaneseUnicodes.KANA
            if 'kanji' in charset:
                ranges += JapaneseUnicodes.KANJI
            if 'punct' in charset:
                ranges += JapaneseUnicodes.PUNCTUATION
            if 'misc' in charset:
                ranges += JapaneseUnicodes.MISC
        # Raise instead of assert: asserts are stripped under ``-O``,
        # which would let an invalid charset through silently.
        if not ranges:
            raise ValueError("Invalid character set: {!r}".format(charset))
        return ranges
class Vocabulary(object):
    """Simple vocabulary wrapper.

    References:
        Lightly modified TensorFlow "im2txt" `Vocabulary`_.

    .. _Vocabulary: https://github.com/tensorflow/models/blob/master/
        research/im2txt/im2txt/data/build_mscoco_data.py
    """

    # Token substituted for out-of-vocabulary characters.
    UNK = "<UNK>"

    def __init__(self, reserved, vocab):
        """Initializes the vocabulary.

        Args:
            reserved (tuple): Tuple of reserved tokens.
            vocab (list): List of vocabulary entries, ideally (for
                visualization) in descending order by frequency.
        """
        self._vocab = {}
        for ix, char in enumerate(vocab):
            self._vocab[char] = ix
        # Shift every vocabulary ID up by the number of reserved slots
        # that would collide, so reserved tokens can take the low IDs.
        add2id = 0
        for i in range(len(reserved)):
            if i in self._vocab.values():
                add2id += 1
        self._vocab = {key: idx + add2id for key, idx in self._vocab.items()}
        for i, char in enumerate(reserved):
            self._vocab[char] = i
        try:
            self._unk_id = reserved.index(self.UNK)
        except ValueError:
            print("'{}' token not provided. Setting to highest ID.".format(
                self.UNK
            ))
            # BUG FIX: the original never assigned ``_unk_id`` on this
            # path, so any out-of-vocabulary char_to_id() call raised
            # AttributeError instead of returning the UNK id.
            self._unk_id = len(self._vocab)
            self._vocab[self.UNK] = self._unk_id
        self._rev_vocab = {idx: key for key, idx in self._vocab.items()}

    def save(self, out_dir):
        """Writes the vocabulary to ``out_dir``/vocab.txt, one token
        per line in ascending ID order.

        Args:
            out_dir (str): Directory to write the vocab file into.
        """
        vocab_sorted = [self._rev_vocab[idx]
                        for idx in sorted(self._rev_vocab.keys())]
        with open(os.path.join(out_dir, 'vocab.txt'), 'w') as f:
            for token in vocab_sorted:
                f.write(token + '\n')

    def char_to_id(self, char):
        """Returns the integer id of a character string."""
        if char in self._vocab:
            return self._vocab[char]
        else:
            return self._unk_id

    def id_to_char(self, char_id):
        """Returns the character string of a integer id."""
        if char_id in self._rev_vocab:
            return self._rev_vocab[char_id]
        else:
            return self.UNK

    def get_num_classes(self):
        """Returns number of classes, includes <UNK>."""
        return len(self._vocab)