# Source code for carpedm.data.ops

#
# Copyright (c) 2018 Neal Digre.
#
# This software may be modified and distributed under the terms
# of the MIT license. See the LICENSE file for details.


"""Data operations.

This module contains several non-module-specific data operations.

Todo:
    * Tests
        * ``to_sequence_example``, ``parse_sequence_example``
        * ``sparsify_label``
        * ``shard_batch``
        * ``in_line``
        * ``ixs_in_region``
        * ``seq_norm_bbox_values``
"""
import tensorflow as tf


def _int64_feature(value):
    """Build an int64 ``tf.train.Feature`` for an Example proto.

    Args:
        value: A single integer or a ``list`` of integers. Any value that
            is not a ``list`` (including a tuple) is wrapped in a
            one-element list.

    Returns:
        :obj:`tf.train.Feature`: Feature holding an ``Int64List``.
    """
    values = value if isinstance(value, list) else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_feature(value):
    """Build a float ``tf.train.Feature`` for an Example proto.

    Args:
        value: A single number or a ``list`` of numbers. Any value that is
            not a ``list`` (including a tuple) is wrapped in a one-element
            list.

    Returns:
        :obj:`tf.train.Feature`: Feature holding a ``FloatList``.
    """
    values = value if isinstance(value, list) else [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _bytes_feature(value):
    """Build a bytes ``tf.train.Feature`` for an Example proto.

    Args:
        value: A single bytes object or a ``list`` of them. Any value that
            is not a ``list`` (including a tuple) is wrapped in a
            one-element list.

    Returns:
        :obj:`tf.train.Feature`: Feature holding a ``BytesList``.
    """
    values = value if isinstance(value, list) else [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def _int64_feature_list(values):
    """Build a ``tf.train.FeatureList`` with one int64 feature per item.

    Args:
        values: Iterable whose items are each converted by
            ``_int64_feature``.

    Returns:
        :obj:`tf.train.FeatureList`
    """
    features = list(map(_int64_feature, values))
    return tf.train.FeatureList(feature=features)


def _float_feature_list(values):
    """Build a ``tf.train.FeatureList`` with one float feature per item.

    Args:
        values: Iterable whose items are each converted by
            ``_float_feature``.

    Returns:
        :obj:`tf.train.FeatureList`
    """
    features = list(map(_float_feature, values))
    return tf.train.FeatureList(feature=features)


def _bytes_feature_list(values):
    """Build a ``tf.train.FeatureList`` with one bytes feature per item.

    Args:
        values: Iterable whose items are each converted by
            ``_bytes_feature``.

    Returns:
        :obj:`tf.train.FeatureList`
    """
    features = list(map(_bytes_feature, values))
    return tf.train.FeatureList(feature=features)


def to_sequence_example(feature_dict):
    """Convert features to TensorFlow SequenceExample.

    Keys are routed by substring, checked in this order:

    * keys containing ``'seq'`` become entries of ``feature_lists``: float
      feature lists when the key also contains ``'bbox'``, int64 feature
      lists otherwise;
    * other keys containing ``'data'`` are serialized with
      ``value.tobytes()`` and stored as a bytes context feature;
    * every remaining key is stored as an int64 context feature.

    Args:
        feature_dict (dict): Dictionary of features.

    Returns:
        :obj:`tf.train.SequenceExample`
    """
    context_features = {}
    sequence_features = {}
    for key, value in feature_dict.items():
        if 'seq' in key:
            if 'bbox' in key:
                sequence_features[key] = _float_feature_list(value)
            else:
                sequence_features[key] = _int64_feature_list(value)
        elif 'data' in key:
            # ``ndarray.tostring()`` is a deprecated alias of ``tobytes()``
            # that returns the same bytes; use the supported name.
            # NOTE(review): presumably ``value`` is a uint8 image array, since
            # parse_sequence_example decodes this field with
            # tf.decode_raw(..., tf.uint8) -- confirm with callers.
            context_features[key] = _bytes_feature(value.tobytes())
        else:
            context_features[key] = _int64_feature(value)

    context = tf.train.Features(feature=context_features)
    feature_lists = tf.train.FeatureLists(feature_list=sequence_features)
    return tf.train.SequenceExample(context=context,
                                    feature_lists=feature_lists)
def parse_sequence_example(serialized):
    """Decode one serialized SequenceExample into a flat feature dict.

    Args:
        serialized (:obj:`tf.Tensor`): Serialized 0-D tensor of type string.

    Returns:
        dict: Dictionary of features:

            * ``image/data``: raw bytes decoded as ``tf.uint8`` and reshaped
              to ``[height, width, 3]`` (three channels are assumed);
            * ``image/height``, ``image/width``, ``image/char/count``,
              ``image/line/count``: cast to ``tf.int32`` (the two counts
              default to 0 when absent);
            * ``image/seq/char/bbox`` and ``image/seq/line/bbox``: the four
              per-coordinate float sequences concatenated into an
              ``[N, 4]`` tensor with columns (ymin, xmin, ymax, xmax);
            * ``image/seq/char/id``: cast to ``tf.int32``.
    """
    sequence_spec = {
        "image/seq/char/id": tf.FixedLenSequenceFeature(
            [], tf.int64, allow_missing=True),
    }
    # Eight float sequences: {char, line} x {xmin, ymin, xmax, ymax}.
    for obj in ('char', 'line'):
        for coord in ('xmin', 'ymin', 'xmax', 'ymax'):
            spec_key = "image/seq/{}/bbox/{}".format(obj, coord)
            sequence_spec[spec_key] = tf.FixedLenSequenceFeature(
                [], tf.float32, allow_missing=True)

    context, sequence = tf.parse_single_sequence_example(
        serialized,
        context_features={
            "image/data": tf.FixedLenFeature([], tf.string),
            "image/height": tf.FixedLenFeature([], tf.int64),
            "image/width": tf.FixedLenFeature([], tf.int64),
            "image/char/count": tf.FixedLenFeature([], tf.int64, 0),
            "image/line/count": tf.FixedLenFeature([], tf.int64, 0),
        },
        sequence_features=sequence_spec,
    )

    parsed = {}
    height = tf.cast(context['image/height'], tf.int32)
    width = tf.cast(context['image/width'], tf.int32)
    for name, tensor in context.items():
        if 'data' not in name:
            parsed[name] = tf.cast(tensor, tf.int32)
            continue
        pixels = tf.decode_raw(tensor, tf.uint8)
        parsed[name] = tf.reshape(pixels, [height, width, 3])

    # Merge each object's per-coordinate sequences into one [N, 4] tensor.
    # Popping them leaves only the non-bbox sequences in ``sequence``.
    for obj in ('char', 'line'):
        columns = [
            tf.reshape(
                sequence.pop("image/seq/{}/bbox/{}".format(obj, coord)),
                [-1, 1])
            for coord in ('ymin', 'xmin', 'ymax', 'xmax')
        ]
        parsed['image/seq/' + obj + '/bbox'] = tf.concat(
            axis=1, values=columns)

    for name, tensor in sequence.items():
        parsed[name] = tf.cast(tensor, tf.int32)
    return parsed
def sparsify_label(label, length):
    """Convert a regular Tensor into a SparseTensor.

    Only non-zero entries of ``label`` are kept; zeros are dropped.
    NOTE(review): zeros presumably act as padding -- confirm with callers.
    Because the result's ``dense_shape`` is ``[length]``, ``label`` is
    expected to be 1-D.

    Args:
        label (:obj:`tf.Tensor`): The label to convert.
        length (:obj:`tf.Tensor`): Length of the label (cast to
            ``tf.int64`` for the dense shape).

    Returns:
        tf.SparseTensor: Sparse label whose values are ``tf.int32``.
    """
    dense_len = tf.cast(length, dtype=tf.int64)
    nonzero_ix = tf.where(tf.not_equal(label, 0))
    nonzero_vals = tf.cast(tf.gather_nd(label, nonzero_ix), tf.int32)
    return tf.SparseTensor(nonzero_ix, nonzero_vals, dense_shape=[dense_len])
def shard_batch(features, labels, batch_size, num_shards):
    """Shard a batch of examples round-robin across ``num_shards`` shards.

    Example ``i`` of the batch goes to shard ``i % num_shards``; examples
    keep their relative order within each shard.

    Args:
        features (dict): Dictionary of batched feature tensors. Each value
            is unstacked along axis 0 into ``batch_size`` examples.
        labels (:obj:`tf.Tensor`): Batched labels, unstacked along axis 0
            into ``batch_size`` examples.
        batch_size (int): The batch size.
        num_shards (int): Number of shards into which batch is split.

    Returns:
        tuple: ``(feature_shards, label_shards)`` where

            * ``feature_shards`` (:obj:`list` of :obj:`dict`) has one dict
              per shard, mapping each feature key to the ``tf.stack`` of
              that shard's examples;
            * ``label_shards`` (:obj:`list` of :obj:`list`) has one list per
              shard holding that shard's individual (unstacked) label
              tensors.

    Note:
        If ``batch_size < num_shards``, some shards receive no examples:
        their label list is empty and their features are stacked from an
        empty list.
    """
    label_batch = tf.unstack(labels, num=batch_size, axis=0)
    # A step-``num_shards`` slice starting at ``s`` selects exactly the
    # indices i with i % num_shards == s, in order (round-robin).
    label_shards = [label_batch[s::num_shards] for s in range(num_shards)]

    feature_shards = [{} for _ in range(num_shards)]
    for key, batched in features.items():
        examples = tf.unstack(batched, num=batch_size, axis=0)
        for s in range(num_shards):
            feature_shards[s][key] = tf.stack(examples[s::num_shards])
    return feature_shards, label_shards
def in_line(xmin_line, xmax_line, ymin_line, xmin_new, xmax_new, ymax_new):
    """Heuristic for determining whether a character is in a line.

    Note:
        Currently dependent on the order in which characters are added.
        For example, a character may vertically overlap with a line, but
        adding it to the line would be out of reading order. This should
        be fixed in a future version.

    Args:
        xmin_line (:obj:`list` of :obj:`int`): Minimum x-coordinate of
            each character already in the line.
        xmax_line (:obj:`list` of :obj:`int`): Maximum x-coordinate of
            each character already in the line.
        ymin_line (int): Minimum y-coordinate of the line.
        xmin_new (int): Minimum x-coordinate of new character.
        xmax_new (int): Maximum x-coordinate of new character.
        ymax_new (int): Maximum y-coordinate of new character.

    Returns:
        bool: True when the new character's x-extent overlaps the line's
        average x-extent (mean of ``xmin_line`` to mean of ``xmax_line``,
        bounds inclusive) and ``ymax_new`` is at or below ``ymin_line``.

    Raises:
        ZeroDivisionError: If ``xmin_line`` or ``xmax_line`` is empty.
    """
    mean_left = sum(xmin_line) / len(xmin_line)
    mean_right = sum(xmax_line) / len(xmax_line)
    return (mean_left <= xmax_new
            and xmin_new <= mean_right
            and ymin_line <= ymax_new)
def in_region(obj, region, entire=True):
    """Test if an object is in a region.

    Args:
        obj (tuple or BBox): Object bounding box (xmin, xmax, ymin, ymax)
            or point (x, y).
        region (tuple or BBox): Region (xmin, xmax, ymin, ymax).
        entire (bool): For bounding boxes only. If True, the box must lie
            entirely inside ``region``. If False, it only needs to overlap
            ``region``; touching an edge counts as overlap. Ignored for
            points.

    Returns:
        bool: Result. All bounds are inclusive.

    Raises:
        AssertionError: If ``obj`` has neither 4 nor 2 elements (only
            checked when assertions are enabled).
    """
    if len(obj) == 4:
        if entire:
            return (region[0] <= obj[0] <= obj[1] <= region[1] and
                    region[2] <= obj[2] <= obj[3] <= region[3])
        # Two boxes overlap iff their x-intervals AND their y-intervals
        # intersect. Testing whether individual edges fall inside the region
        # is not enough. It accepts boxes aligned with the region on one axis
        # but far away on the other. It also misses boxes that enclose the
        # region.
        return (obj[0] <= region[1] and region[0] <= obj[1] and
                obj[2] <= region[3] and region[2] <= obj[3])
    assert len(obj) == 2, "Invalid point or bounding box."
    return (region[0] <= obj[0] <= region[1] and
            region[2] <= obj[1] <= region[3])
def ixs_in_region(bboxes, y1, y2, x1, x2):
    """Return the indices of the boxes that lie entirely inside a region.

    Args:
        bboxes (:obj:`list` of :obj:`carpedm.data.util.BBox`): Bounding
            boxes for object boundaries; each must expose ``xmin``,
            ``xmax``, ``ymin`` and ``ymax`` attributes.
        y1 (int): Top (lowest row index) of region.
        y2 (int): Bottom (highest row index) of region.
        x1 (int): Left side (lowest column index) of region.
        x2 (int): Right side (highest column index) of region.

    Returns:
        :obj:`list` of :obj:`int`: Indices, in ascending order, of boxes
        contained in the region (bounds inclusive).
    """
    return [
        ix for ix, box in enumerate(bboxes)
        if x1 <= box.xmin and box.xmax <= x2
        and y1 <= box.ymin and box.ymax <= y2
    ]
def seq_norm_bbox_values(bboxes, height, width):
    """Split bounding boxes into per-coordinate lists, normalized by size.

    x-coordinates are divided by ``width`` and y-coordinates by ``height``.

    Args:
        bboxes (:obj:`list` of :obj:`carpedm.data.util.BBox`): Bounding
            boxes to process.
        height (int): Height (in pixels) of image bboxes are in.
        width (int): Width (in pixels) of image bboxes are in.

    Returns:
        tuple: :obj:`tuple` containing:

            :obj:`list` of :obj:`float`: Normalized minimum x-values
            :obj:`list` of :obj:`float`: Normalized minimum y-values
            :obj:`list` of :obj:`float`: Normalized maximum x-values
            :obj:`list` of :obj:`float`: Normalized maximum y-values
    """
    # Materialize once so one-shot iterables are traversed a single time.
    boxes = list(bboxes)
    norm_xmin = [box.xmin / width for box in boxes]
    norm_ymin = [box.ymin / height for box in boxes]
    norm_xmax = [box.xmax / width for box in boxes]
    norm_ymax = [box.ymax / height for box in boxes]
    return norm_xmin, norm_ymin, norm_xmax, norm_ymax