# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.python.debug.lib import debug_data
from tensorboard.plugins.debugger import tensor_helper

# TODO(cais): Defer some tensor values to TensorBase, within a bytes limit,
# before discarding them.


class _TensorValueDiscarded(object):
  def __init__(self, watch_key, time_index):
    self._watch_key = watch_key
    self._time_index = time_index

  @property
  def watch_key(self):
    return self._watch_key

  @property
  def time_index(self):
    return self._time_index

  @property
  def nbytes(self):
    return 0


class _WatchStore(object):
  """The store for a single debug tensor watch.

  Discards data according to a pre-set byte limit.
  """

  def __init__(self, watch_key, mem_bytes_limit=10e6):
    """Constructor of _WatchStore.

    The overflow behavior works as follows: the most recent tensor values are
    stored in memory, up to `mem_bytes_limit` bytes, and at least one (the
    most recent) value is always kept in memory. Older values that exceed
    the limit are discarded.

    Args:
      watch_key: A string representing the debugger tensor watch, with the
        format <NODE_NAME>:<OUTPUT_SLOT>:<DEBUG_OP_NAME>, e.g.,
        'Dense_1/BiasAdd:0:DebugIdentity'.
      mem_bytes_limit: Limit on the number of bytes to store in memory.
    """
    self._watch_key = watch_key
    self._mem_bytes_limit = mem_bytes_limit
    self._in_mem_bytes = 0
    self._disposed = False
    self._data = []  # A map from time index to tensor value.
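
  # Illustrative sketch of the eviction behavior described above (not part of
  # the original module): with a small `mem_bytes_limit`, older values are
  # replaced by `_TensorValueDiscarded` placeholders while the most recent
  # value always survives:
  #
  #   store = _WatchStore('Dense_1/BiasAdd:0:DebugIdentity',
  #                       mem_bytes_limit=1e3)
  #   for _ in range(4):
  #     store.add(np.zeros([100], dtype=np.float64))  # 800 bytes per value.
  #   store.num_total()      # ==> 4
  #   store.num_in_memory()  # ==> 1 (only the most recent value is kept)
  #   store.num_discarded()  # ==> 3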

  def add(self, value):
    """Add a tensor to the watch store."""
    if self._disposed:
      raise ValueError(
          'Cannot add value: this _WatchStore instance is already disposed')
    self._data.append(value)
    if hasattr(value, 'nbytes'):
      self._in_mem_bytes += value.nbytes
    self._ensure_bytes_limits()

  def _ensure_bytes_limits(self):
    # TODO(cais): Thread safety?
    if self._in_mem_bytes <= self._mem_bytes_limit:
      return

    # Walk backwards from the most recent value, accumulating in-memory
    # sizes, to find the last time index that must be discarded.
    i = len(self._data) - 1
    cum_mem_size = 0
    while i >= 0:
      if hasattr(self._data[i], 'nbytes'):
        cum_mem_size += self._data[i].nbytes
      if i < len(self._data) - 1 and cum_mem_size > self._mem_bytes_limit:
        # Always keep at least one time index in memory.
        break
      i -= 1
    # i is now the last time index to discard.

    # Mark the remaining ones as discarded and release their bytes from the
    # in-memory tally.
    while i >= 0:
      if not isinstance(self._data[i], _TensorValueDiscarded):
        if hasattr(self._data[i], 'nbytes'):
          self._in_mem_bytes -= self._data[i].nbytes
        self._data[i] = _TensorValueDiscarded(self._watch_key, i)
      i -= 1

  def num_total(self):
    """Get the total number of values."""
    return len(self._data)

  def num_in_memory(self):
    """Get the number of values currently stored in memory."""
    n = len(self._data) - 1
    while n >= 0:
      if isinstance(self._data[n], _TensorValueDiscarded):
        break
      n -= 1
    return len(self._data) - 1 - n

  def num_discarded(self):
    """Get the number of values discarded due to the memory limit."""
    if not self._data:
      return 0
    n = 0
    while n < len(self._data):
      if not isinstance(self._data[n], _TensorValueDiscarded):
        break
      n += 1
    return n

  def query(self, time_indices):
    """Query the values at given time indices.

    Args:
      time_indices: 0-based time indices to query, as a `list` of `int`.

    Returns:
      Values as a `list` of `numpy.ndarray` (for time indices in memory) or
      `None` (for time indices discarded).
    """
    if self._disposed:
      raise ValueError(
          'Cannot query: this _WatchStore instance is already disposed')
    if not isinstance(time_indices, (tuple, list)):
      time_indices = [time_indices]
    output = []
    for time_index in time_indices:
      if isinstance(self._data[time_index], _TensorValueDiscarded):
        output.append(None)
      else:
        data_item = self._data[time_index]
        if (hasattr(data_item, 'dtype') and
            tensor_helper.translate_dtype(data_item.dtype) == 'string'):
          _, _, data_item = tensor_helper.array_view(data_item)
          data_item = np.array(
              tensor_helper.process_buffers_for_display(data_item),
              dtype=np.object)
        output.append(data_item)
    return output
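
  # For example (an illustrative sketch, continuing the eviction example
  # above), querying a mix of discarded and in-memory time indices:
  #
  #   store.query([0, 3])  # ==> [None, <numpy.ndarray of zeros>]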

  def dispose(self):
    self._disposed = True


class TensorStore(object):

  def __init__(self, watch_mem_bytes_limit=10e6):
    """Constructor of TensorStore.

    Args:
      watch_mem_bytes_limit: Limit on the number of bytes to store in memory
        for each watch key.
    """
    self._watch_mem_bytes_limit = watch_mem_bytes_limit
    # A map from watch key to _WatchStore instances.
    self._tensor_data = dict()

  def add(self, watch_key, tensor_value):
    """Add a tensor value.

    Args:
      watch_key: A string representing the debugger tensor watch, e.g.,
        'Dense_1/BiasAdd:0:DebugIdentity'.
      tensor_value: The value of the tensor as a `numpy.ndarray`.
    """
    if watch_key not in self._tensor_data:
      self._tensor_data[watch_key] = _WatchStore(
          watch_key, mem_bytes_limit=self._watch_mem_bytes_limit)
    self._tensor_data[watch_key].add(tensor_value)
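
  # For instance (an illustrative sketch, not part of the original module):
  #
  #   tensor_store = TensorStore()
  #   tensor_store.add('Dense_1/BiasAdd:0:DebugIdentity', np.zeros([2, 2]))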

  def query(self, watch_key, time_indices=None, slicing=None, mapping=None):
    """Query the tensor store for a given watch_key.

    Args:
      watch_key: The watch key to query.
      time_indices: A numpy-style slicing string for time indices, e.g.,
        `-1`, `:-2`, `::2`. If not provided (`None`), `-1` (the most recent
        time index) is used.
      slicing: A numpy-style slicing string for the individual time steps.
      mapping: A mapping string or a list of them. Supported mappings:
        `{None, 'image/png', 'health-pill'}`.

    Returns:
      The potentially sliced values, as a nested `list` of values or their
      mapped format.

    Raises:
      KeyError: If the watch_key does not exist.
      ValueError: If the shape of the sliced array is incompatible with the
        mapping mode, or if the mapping type is invalid.
    """
    if watch_key not in self._tensor_data:
      raise KeyError("watch_key not found: %s" % watch_key)

    if time_indices is None:
      time_indices = '-1'
    time_slicing = tensor_helper.parse_time_indices(time_indices)
    all_time_indices = list(range(self._tensor_data[watch_key].num_total()))
    sliced_time_indices = all_time_indices[time_slicing]
    if not isinstance(sliced_time_indices, list):
      sliced_time_indices = [sliced_time_indices]

    recombine_and_map = False
    step_mapping = mapping
    if len(sliced_time_indices) > 1 and mapping not in (None, 'none'):
      # Defer mapping until the individual time steps are recombined below.
      recombine_and_map = True
      step_mapping = None

    output = []
    for index in sliced_time_indices:
      value = self._tensor_data[watch_key].query(index)[0]
      if (value is not None and
          not isinstance(value, debug_data.InconvertibleTensorProto)):
        output.append(tensor_helper.array_view(
            value, slicing=slicing, mapping=step_mapping)[2])
      else:
        output.append(None)

    if recombine_and_map:
      if mapping == 'image/png':
        output = tensor_helper.array_to_base64_png(output)
      elif mapping and mapping != 'none':
        tf.logging.warn(
            'Unsupported mapping mode after recombining time steps: %s',
            mapping)
    return output
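
  # Illustrative usage of `query` (a sketch under the same assumptions as the
  # example after `add` above):
  #
  #   # The most recent value of the watched tensor:
  #   tensor_store.query('Dense_1/BiasAdd:0:DebugIdentity')
  #   # The first row of every stored time step:
  #   tensor_store.query('Dense_1/BiasAdd:0:DebugIdentity',
  #                      time_indices=':', slicing='[0, :]')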

  def dispose(self):
    for watch_key in self._tensor_data:
      self._tensor_data[watch_key].dispose()