from copy import copy
from beliefs.cells import *
from belief_utils import choose
import numpy as np
import itertools
class BeliefState(DictCell):
    """
    Represents a beliefstate, a partial information object *about* specific targets.

    A beliefstate is a continuum between individual entities and types of entities.
    Suppose the referential domain has entities: 1,2,3.  A beliefstate represents the
    possible groupings of these entities; a single grouping is called a *target*.
    A beliefstate can be about "all targets of size two", for example, and computing
    the beliefstate's extension would yield the targets {1,2}, {2,3}, and {1,3}.

    In addition to containing a description of the intended targets, a belief state
    contains meta-data about combinatoric constraints (such as arity size).
    """

    def __init__(self, referential_domain=None):
        """
        Initializes an empty beliefstate, and stores the referential domain
        (if specified) into ``self.__dict__['referential_domain']``.

        By default, beliefstates are given the 'S' part of speech, an empty
        environment variable mapping and an empty list of deferred effects.

        Most commonly, beliefstates are created by calling the `copy()` method
        from a previous beliefstate.

        :param referential_domain: The domain of entities this state is about (optional)
        """
        # Bookkeeping attributes are written straight into __dict__
        # (presumably because DictCell customises attribute access -- confirm).
        self.__dict__['pos'] = 'S'  # syntactic state
        self.__dict__['referential_domain'] = referential_domain
        self.__dict__['environment_variables'] = {}
        self.__dict__['deferred_effects'] = []

        default_structure = {
            'target': DictCell(),
            'distractor': DictCell(),
            'targetset_arity': IntervalCell(0, np.inf),
            'contrast_arity': IntervalCell(0, np.inf),
            'is_in_commonground': BoolCell(),
            'speaker_model': {'is_syntax_stacked': BoolCell(F)},
        }
        DictCell.__init__(self, default_structure)

    def set_pos(self, pos):
        """
        Sets the beliefstate's part of speech, `pos`, and then executes
        any deferred effects that are keyed by that pos tag.

        :param pos: The part of speech of the beliefstate
        :type pos: str
        :returns: int,float -- The cost of the deferred effects associated with this part of speech
        """
        self.__dict__['pos'] = pos
        # if any deferred effects are keyed by this pos tag, evaluate them, and
        # return their cumulative costs
        return self.execute_deferred_effects(pos)

    def get_pos(self):
        """
        Returns Part of Speech

        :returns: str -- The part of speech for this BeliefState. BeliefStates are initialized with a default pos of 'S'
        """
        return self.__dict__['pos']

    def add_deferred_effect(self, effect, pos):
        """
        Registers a (pos, effect) tuple to later be executed if the
        state reaches the 'pos'.

        When the speaker model's ``is_syntax_stacked`` cell is true the tuple is
        inserted at the front of the list, when it is false the tuple is appended.

        :param effect: A function that takes one argument (the beliefstate) and returns a number representing the cost associated with this pos
        :type effect: function
        :param pos: Part of Speech
        :type pos: str
        :raises: Exception -- if *pos* is not a string
        :raises: Contradiction -- if the speaker model is neither true nor false
        """
        try:
            string_types = (unicode, str)
        except NameError:  # Python 3 has no separate ``unicode`` type
            string_types = (str,)

        if not isinstance(pos, string_types):
            # %s, not %d: formatting a type object with %d raised a TypeError
            # that masked this (intended) error.
            raise Exception("Invalid POS tag. Must be string not %s" % (type(pos),))

        stacked = self['speaker_model']['is_syntax_stacked']
        # The cell is compared against both True and False explicitly because
        # it may be in neither state (see the final branch).
        if stacked == True:
            self.__dict__['deferred_effects'].insert(0, (pos, effect,))
        elif stacked == False:
            self.__dict__['deferred_effects'].append((pos, effect,))
        else:
            raise Contradiction("Speaker Model undefined")

    def execute_deferred_effects(self, pos):
        """
        Evaluates deferred effects that are triggered by the prefix of the
        pos on the current beliefstate. For instance, if the effect is triggered
        by the 'NN' pos, then the effect will be triggered by 'NN' or 'NNS'.

        Every triggered effect is called with this beliefstate and afterwards
        removed from the list of deferred effects.

        :param pos: A part of speech
        :type pos: str
        :returns: number -- Represents the cost of the deferred effects associated with *pos*. Can be int or float.
        """
        costs = 0
        to_delete = []
        for entry in self.__dict__['deferred_effects']:
            effect_pos, effect = entry
            if pos.startswith(effect_pos):
                costs += effect(self)
                to_delete.append(entry)
        # we delete afterwards, because removing from a list while iterating
        # over it would skip entries.
        for entry in to_delete:
            self.__dict__['deferred_effects'].remove(entry)
        return costs

    def set_environment_variable(self, key, val):
        """
        Sets a variable if that variable has not already been set

        :param key: The key for the environment variable
        :param val: The value that will be assigned to *key*
        :raises: Contradiction -- raised if *key* has already been set to a value other than *val*
        """
        if self.get_environment_variable(key) in [None, val]:
            self.__dict__['environment_variables'][key] = val
        else:
            raise Contradiction("Could not set environment variable %s" % (key))

    def get_environment_variable(self, key, default=None, pop=False):
        """
        Returns the value associated with *key*.

        :param key: The lookup key
        :param default: This value is returned in the event that no value is associated with *key*
        :param pop: Determines whether *key* is removed from the Environment Variables when it is found.
        :type pop: bool
        :returns: The value associated with *key*, or *default*
        """
        variables = self.__dict__['environment_variables']
        if pop:
            return variables.pop(key, default)
        return variables.get(key, default)

    def has_referential_domain(self):
        """
        Returns ``True`` if the BeliefState has a context set defined -- meaning a set
        of external referents.

        :returns: bool -- Whether a contextset is defined
        """
        return self.__dict__['referential_domain'] is not None

    def iter_breadth_first(self, root=None):
        """
        Traverses the belief state's structure breadth-first

        The generator consumes a second instance of itself to obtain the nodes
        whose children still have to be emitted; it stops once the node being
        expanded is the last item that was yielded.

        :param root: Optional starting point for the search (defaults to ``self``)
        :returns: Generator
        """
        # identity test: ``== None`` would go through the hash-based __eq__
        if root is None:
            root = self
        yield root
        last = root
        for node in self.iter_breadth_first(root):
            if isinstance(node, DictCell):
                # expand this node
                for subpart in node:
                    yield subpart
                    last = subpart
            if last == node:
                return

    def find_path(self, test_function=None, on_targets=False):
        """
        This general helper method walks the nested (name, structure) parts and
        yields all paths (lists of names) to the elements where the
        *test_function* evaluates to ``True``.

        :param test_function: A function that takes two arguments (name, structure) and returns a boolean.  Defaults to accepting everything.
        :type test_function: function
        :param on_targets: If true, search the first singleton referent and prefix each path with ``'target'``; otherwise search this beliefstate itself.
        :type on_targets: bool
        :returns: Generator -- each item is a path (list of names)
        """
        assert self.has_referential_domain(), "need context set"

        if not test_function:
            test_function = lambda x, y: True

        def find_path_inner(part, prefix):
            # depth-first walk below one (name, structure) pair
            name, structure = part
            if test_function(name, structure):
                yield prefix + [name]
            if isinstance(structure, DictCell):
                for sub_structure in structure:
                    for prefix2 in find_path_inner(sub_structure,
                                                   prefix[:] + [name]):
                        yield prefix2

        prefix = []
        if on_targets:
            # apply search to the first target
            results = []
            for _, instance in self.iter_singleton_referents():
                for part in instance:
                    for entry in find_path_inner(part, prefix[:]):
                        results.append(['target'] + entry)
                # paths are handed out last-found-first
                while results:
                    yield results.pop()
                break  # only use first instance
        else:
            # apply search to self
            for part in self:
                for entry in find_path_inner(part, prefix[:]):
                    yield entry

    def get_nth_unique_value(self, keypath, n, distance_from, open_interval=True):
        """
        Returns element *n* of the list produced by :meth:`get_ordered_values`,
        or raises a contradiction if that is out of bounds.

        :param keypath: Path to the attribute whose values are ranked
        :type keypath: list
        :param n: An integer representing which unique value to return
        :type n: int
        :param distance_from: Passed through to :meth:`get_ordered_values` ('min', 'max' or 'mean')
        :param open_interval: Passed through to :meth:`get_ordered_values`
        :type open_interval: bool
        :returns: The selected entry of the ordered values
        :raises: Contradiction
        """
        unique_values = self.get_ordered_values(keypath, distance_from, open_interval)
        # NOTE(review): the bound is on n+1, so the final entry is only
        # reachable through n == -1 -- confirm this is intended.
        if 0 <= n+1 < len(unique_values):
            return unique_values[n]
        else:
            raise Contradiction("n-th Unique value out of range: " + str(n))

    def get_ordered_values(self, keypath, distance_from, open_interval=True):
        """
        Retrieves the referents' values as intervals, ordered by their distance
        from the min, max, or mean value.

        An empty list is returned when there are no singleton referents or when
        any referent's value has differing ``low`` and ``high`` bounds.

        :param keypath: Path to the attribute; a leading ``'target'`` is dropped
        :param distance_from: 'min', 'max' or 'mean' -- the anchor distances are measured from
        :param open_interval: For 'min'/'max', whether the interval is unbounded on the far side (``True``) or collapses to a single point (``False``)
        :type open_interval: bool
        :returns: list -- IntervalCells, one per unique distance (for 'mean' the largest distance is skipped)
        """
        # local import: this module does not import ``logging`` explicitly
        import logging
        # trace output only, hence debug rather than error severity
        logging.debug("Looking for values from path"+str(keypath))

        values = []
        if keypath[0] == 'target':
            # instances start with 'target' prefix, but
            # don't contain it, so we remove it here.
            keypath = keypath[1:]
        for _, instance in self.iter_singleton_referents():
            value = instance.get_value_from_path(keypath)
            if hasattr(value, 'low') and value.low != value.high:
                return []
            values.append(float(value))

        if len(values) == 0:
            return []

        values = np.array(values)
        # distance of each value from the anchor; 'min' is the fallback
        diffs = values - values.min()
        if distance_from == 'max':
            diffs = values.max() - values
        if distance_from == 'mean':
            diffs = abs(values.mean() - values)

        sdiffs = np.unique(diffs)  # np.unique already returns sorted values
        last_ix = len(sdiffs) - 1

        results = []
        for ix, el in enumerate(sdiffs):
            # all values at most this far away from the anchor
            vals = values[diffs <= el]
            if distance_from == 'max':
                lowest = vals.min()
                upper = np.inf if open_interval else lowest
                results.append(IntervalCell(lowest, upper))
            elif distance_from == 'min':
                highest = vals.max()
                lower = -np.inf if open_interval else highest
                results.append(IntervalCell(lower, highest))
            elif distance_from == 'mean':
                if ix == last_ix:
                    continue  # skip last
                results.append(IntervalCell(vals.min(), vals.max()))
        return results

    def get_paths_for_attribute_set(self, keys):
        """
        Given a list/set of keys (or one key), returns the parts that have
        all of the keys in the list.

        Because on_targets=True, this DOES NOT WORK WITH TOP LEVEL PROPERTIES,
        only those of targets.

        These paths are not pointers to the objects themselves, but lists of
        attribute names that allow us to (attempt) to look up that object in any
        belief state.

        :param keys: A collection of keys
        :type keys: list,set
        :returns: Generator
        """
        if not isinstance(keys, (list, set)):
            keys = [keys]

        def has_all_keys(name, structure):
            return all(k in structure for k in keys)

        return self.find_path(has_all_keys, on_targets=True)

    def get_parts(self):
        """
        Searches for all DictCells (with nested structure)

        :returns: Generator
        """
        # find_path calls the test with (name, structure); the previous
        # one-argument lambda raised a TypeError on first use.
        is_part = lambda name, structure: isinstance(structure, DictCell)
        return self.find_path(is_part, on_targets=True)

    def get_paths_for_attribute(self, attribute_name):
        """
        Returns a path list to all attributes that have a particular name.

        :param attribute_name: The name to look for
        :returns: Generator
        """
        has_name = lambda name, structure: name == attribute_name
        return self.find_path(has_name, on_targets=True)

    def merge(self, keypath, value, op='set'):
        """
        First gets the cell at BeliefState's keypath, or creates a new cell
        from the first element in the referential domain that has that keypath.
        Second, this merges that cell with the value.

        If the keypath starts with ``'target'`` and the ``negated`` environment
        variable is set, that variable is consumed and the merge is redirected
        to the ``'distractor'`` branch.

        .. warning::
            If two elements in the referential domain have the same named attributes (i.e. attribute paths)
            but different Cells for the value, then the belief state will arbitrarily acquire the first
            element in the referential domain and not entail the other because its cell will be incomparable.

        :param keypath: Path to the cell
        :type keypath: list
        :param value: The value handed to the cell's *op* method
        :param op: Name of the cell method to call (e.g. ``'set'``)
        :type op: str
        :returns: The result of calling the *op* method of the Cell at the end of *keypath* on *value*
        :raises: CellConstructionFailure, Contradiction, Exception
        """
        negated = False
        keypath = keypath[:]  # copy it
        if keypath[0] == 'target':
            # only pull negated if it can potentially modify target
            negated = self.get_environment_variable('negated', pop=True, default=False)
            if negated:
                keypath[0] = "distractor"

        if keypath not in self:
            first_referent = None
            if keypath[0] in ['target', 'distractor']:
                has_targets = False
                for _, referent in self.iter_singleton_referents():
                    has_targets = True
                    if keypath[1:] in referent:
                        first_referent = referent
                        break
                if first_referent is None:
                    # this happens when none of the available targets have the
                    # path that is attempted to being merged to
                    if has_targets:
                        raise CellConstructionFailure("Cannot merge; no target: %s" \
                            % (str(keypath)))
                    else:
                        # this will always happen when size is 0
                        raise CellConstructionFailure("Empty belief state")
                # find the type and add it to the beliefstate
                cell = first_referent.get_value_from_path(keypath[1:]).stem()
                self.add_cell(keypath, cell)
            else:
                # should we allow merging undefined components outside of target?
                raise Exception("Could not find Keypath %s" % (str(keypath)))

        # walk down the keypath to the cell
        cell = self
        if not isinstance(keypath, list):
            keypath = [keypath]
        for key in keypath:
            cell = cell[key]

        # perform operation (set, <=, >= etc)
        try:
            return getattr(cell, op)(value)
        except Contradiction as ctrd:
            # add more information to the contradiction
            raise Contradiction("Could not merge %s with %s: %s " % (str(keypath), str(value), ctrd))

    def add_cell(self, keypath, cell):
        """
        Adds the new cell `cell` at the end of `keypath`, creating intermediate
        DictCells for path components that do not exist yet.

        :param keypath: A sequence of keys that specify a path through nested dictionaries
        :type keypath: list
        :param cell: The cell to be added to the end of keypath
        :returns: The value found at the new location after insertion
        """
        keypath = keypath[:]  # copy
        inner = self  # the most inner dict where cell is added
        cellname = keypath  # the name of the cell
        assert keypath not in self, "Already exists: %s " % (str(keypath))
        if isinstance(keypath, list):
            while len(keypath) > 1:
                cellname = keypath.pop(0)
                if cellname not in inner:
                    inner.__dict__['p'][cellname] = DictCell()
                inner = inner[cellname]  # move in one
            cellname = keypath[0]
        # now we can add 'cellname'->(Cell) to inner (DictCell)
        inner.__dict__['p'][cellname] = cell
        return inner[cellname]

    def entails(self, other):
        """
        One beliefstate entails another beliefstate iff the other state's entities are
        all equal or more general than the caller's parts.  That means the other
        state must at least have all of the same keys/components.

        Implemented as ``other.is_entailed_by(self)``.

        .. note::
            this only compares the items in the DictCell, not `pos`,`environment_variables` or `deferred_effects`.

        :param other: The BeliefState to compare with
        :type other: BeliefState
        :returns: bool
        :raises: Exception
        """
        return other.is_entailed_by(self)

    def is_entailed_by(self, other):
        """
        Returns True iff *other* contains every key of the calling instance and,
        for each of them, *other*'s cell ``implies`` the caller's cell.
        Inverse of `entails`.

        .. note::
            this only compares the items in the DictCell, not `pos`, `environment_variables` or `deferred_effects`.

        :param other: BeliefState to compare with
        :type other: BeliefState
        :returns: bool
        :raises: Exception -- if a cell of *other* has no ``implies`` method
        """
        for (s_key, s_val) in self:
            if s_key not in other:
                return False
            if not hasattr(other[s_key], 'implies'):
                raise Exception("Cell for %s is missing implies()" % s_key)
            if not other[s_key].implies(s_val):
                return False
        return True

    def is_equal(self, other):
        """
        Two beliefstates are treated as equal iff their hash values are equal.

        .. note::
            because the comparison goes through :meth:`__hash__`, it covers
            `pos`, `environment_variables` and `deferred_effects` in addition
            to the items in the DictCell.

        :param other: BeliefState to compare with
        :type other: BeliefState
        :returns: bool
        """
        # An element-wise comparison used to follow this return statement but
        # was unreachable; it has been removed.
        return hash(self) == hash(other)

    def is_contradictory(self, other):
        """
        Two beliefstates are incompatible if the other beliefstate's entities
        are not consistent with or accessible from the caller.

        Only keys present in both states are compared.

        .. note::
            this only compares the items in the DictCell, not `pos`, `environment_variables` or `deferred_effects`.

        :param other: BeliefState to compare with
        :type other: BeliefState
        :returns: bool
        """
        for (s_key, s_val) in self:
            if s_key in other and s_val.is_contradictory(other[s_key]):
                return True
        return False

    def size(self):
        """
        Returns the size of the belief state. This is the number of referents it implicitly represents.

        Initially if there are :math:`n` consistent members, (the result of :meth:`number_of_singleton_referents`)
        then there are generally :math:`2^{n}-1` valid referents.

        :returns: int -- Size of the contextset
        :raises: Exception
        """
        n = self.number_of_singleton_referents()
        if n == 0:
            return 0
        # Only emptiness matters here, so peek at the first grouping instead
        # of materialising the (exponentially large) full list.
        nothing = object()
        if next(self.iter_referents_tuples(), nothing) is nothing:
            return 0

        tlow, thigh = self['targetset_arity'].get_tuple()
        clow, _ = self['contrast_arity'].get_tuple()
        # NOTE(review): binomial_range is not imported by name in this module
        # (presumably provided by a wildcard import) and the upper contrast
        # bound is not used here although the enumeration honours it -- confirm.
        return binomial_range(n, max(tlow, 1), min([n - max(clow, 0), thigh, n]))

    def referents(self):
        """
        Returns all target sets that are compatible with the current beliefstate.

        .. warning::
            the number of referents can be exponential in the number of elements in the referential domain.
            Call `size()` method instead to compute size only.

        :returns: list -- Members of contextset that are compatible with beliefstate
        """
        # all groupings of singletons
        return list(self.iter_referents())

    def _iter_groupings(self, singletons):
        """
        Yields every combination of *singletons* permitted by the
        ``targetset_arity`` and ``contrast_arity`` cells, largest groups first.
        """
        tlow, thigh = self['targetset_arity'].get_tuple()
        clow, chigh = self['contrast_arity'].get_tuple()
        t = len(singletons)
        low = max(1, tlow)
        high = min([t, thigh])
        for r in reversed(range(low, high + 1)):
            for grouping in itertools.combinations(singletons, r):
                # whatever is left out of the grouping forms the contrast set
                if clow <= t - len(grouping) <= chigh:
                    yield grouping

    def iter_referents(self):
        """
        Generates target sets that are compatible with the current beliefstate.

        :returns: Generator -- tuples of the items produced by :meth:`iter_singleton_referents`
        """
        referents = list(self.iter_singleton_referents())
        for targets in self._iter_groupings(referents):
            yield targets

    def iter_referents_tuples(self):
        """
        Generates target sets (as tuples of indices) that are compatible with
        the current beliefstate.

        :returns: Generator
        """
        singletons = [int(i) for i, _ in self.iter_singleton_referents()]
        for elements in self._iter_groupings(singletons):
            yield elements

    def number_of_singleton_referents(self):
        """
        Returns the number of singleton elements of the referential domain that are
        compatible with the current belief state.

        This is the size of the union of all referent sets.

        :returns: int -- The number of singleton members of the contextset
        :raises: Exception -- Raised when no contextset is defined for the BeliefState
        """
        if not self.__dict__['referential_domain']:
            raise Exception("self.referential_domain must be defined")
        return sum(1 for _ in self.iter_singleton_referents())

    def iter_singleton_referents(self):
        """
        Iterator of all of the singleton members of the context set.

        A member is produced when the ``target`` cell is entailed by it and the
        ``distractor`` cell is either empty or not entailed by it.

        :returns: Generator -- yields ``(member['num'], member)`` pairs
        :raises: Exception -- if no referential domain is defined
        """
        # Only the lookup of the domain is guarded, so that a KeyError raised
        # while examining a member is not misreported as a missing domain.
        try:
            domain = self.__dict__['referential_domain']
        except KeyError:
            raise Exception("No referential_domain defined")
        if domain is None:
            raise Exception("No referential_domain defined")

        target = self['target']
        for member in domain.iter_entities():
            if not target.is_entailed_by(member):
                continue
            distractor = self['distractor']
            if distractor.empty() or not distractor.is_entailed_by(member):
                yield member['num'], member

    def to_latex(self, number=0):
        r"""
        Returns a raw text string that contains a latex representation of
        the belief state as an attribute-value matrix.  This requires:

            \usepackage{avm}

        :param number: Index used as the subscript of the matrix name
        :type number: int
        :returns: str
        """
        latex = ("\\avmfont{\\sc}\n"
                 "\\avmoptions{sorted,active}\n"
                 "\\avmvalfont{\\rm}")
        latex += "\n\nb_%i = \\begin{avm} \n " % number
        latex += DictCell.to_latex(self)
        latex += "\n\\end{avm}\n"
        return latex

    def copy(self):
        """
        Copies the BeliefState by recursively deep-copying all of
        its parts.  Domains are not copied, as they do not change
        during the interpretation or generation.

        :returns: BeliefState
        """
        # Imported locally: at module level the name ``copy`` is bound to the
        # function ``copy.copy``, which has no ``deepcopy`` attribute.
        from copy import deepcopy

        copied = BeliefState(self.__dict__['referential_domain'])
        for key in ['environment_variables', 'deferred_effects', 'pos', 'p']:
            copied.__dict__[key] = deepcopy(self.__dict__[key])
        return copied

    def __hash__(self):
        """
        This is the all-important hash method that recursively computes a hash
        value from the components of the beliefstate.  The search process treats
        two beliefstates as equal if their hash values are the same.

        The value combines the part of speech, the environment variables, the
        deferred effects and every (key, cell) pair of the dictionary.
        """
        hashval = 0
        # hash part of speech
        hashval += hash(self.__dict__['pos'])
        # hash environment variables
        for ekey, kval in self.__dict__['environment_variables'].items():
            hashval += hash(ekey) + hash(kval)
        # hash deferred effects
        for effect in self.__dict__['deferred_effects']:
            hashval += hash(effect)
        # hash dictionary
        for key, value in self.__dict__['p'].items():
            hashval += hash(value) * hash(key)
        # -2 is treated as a reserved value here
        if hashval == -2:
            hashval = -1
        return hashval

    # equality is defined through the hash value
    __eq__ = is_equal
# Manual smoke test: print the LaTeX rendering of a fresh belief state over
# a named context set.
if __name__ == '__main__':
    from models import *
    from models.online import *
    c = ContextSet.get_by_name("Amazon Kindles")
    b = BeliefState(c)
    # call form prints the same text on Python 2 and Python 3
    print(b.to_latex())