#include "psearch_infer.h"
#include "psearch_session.h"
namespace psearch {
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::infer
//
// Runs the inference loop and returns the result object (m_result_p).
//
// The query state is seeded from the session, then the main loop runs for at
// most m_max_loop iterations or until the candidate queue is empty. Each
// iteration:
//   1. pops candidates off the queue until one passes the c1 check and both
//      negation checks, and infers that one; every rejected candidate is
//      removed from the visited set so it can be queued again later,
//   2. calls select_relevant_patterns to queue further candidates,
//   3. retracts every pattern waiting in the negated list.
//
// verbose: when true, prints the query and one trace line per inferred or
//          retracted pattern on std::cout.
/////////////////////////////////////////////////////////////////////////////////////////
psresult_ptr_type
PSearchInfer::infer(bool verbose)
{
    infer_execution_state query_state;
    psc_queue_type & candidates = query_state.psc_queue;
    pattern_set_type & visited = query_state.visited_pattern_set;
    psc_list_type & pending_retractions = query_state.negated_psc_list;

    if(verbose) printQuery();

    // seed the state and queue the first batch of candidates
    initialize_query_state(query_state);
    select_relevant_patterns(query_state);

    for(unsigned int iloop = 0; iloop < m_max_loop and not candidates.empty(); ++iloop) {
        bool rejected = false;
        do {
            // the queue hands out the highest-priority candidate first
            psc_ptr_type psc_p = candidates.top();
            pattern_index_type pattern_index = psc_p->get_pattern_index();
            candidates.pop();

            // record the c1 outcome on the score card whatever happens next
            bool is_c1 = is_c1_match(pattern_index);
            psc_p->set_c1_match(is_c1);

            // short-circuit keeps the original order: c1 check, then has-not
            // negation, then has-node negation (the last one records the
            // filtered pattern in query_state)
            rejected = not is_c1
                    or is_negated_by_has_not_nodes(pattern_index)
                    or is_negated_by_has_nodes(query_state, psc_p);

            if(rejected) {
                // forget the candidate so it may be selected again later
                visited.erase(pattern_index);
            } else {
                if(verbose) std::cout << "Infer Pattern: " << pattern_index << std::endl;
                // infer pattern's c2 nodes
                infer_pattern(query_state, psc_p);
            }
        } while(rejected and not candidates.empty());

        // queue the patterns made relevant by the nodes just inferred
        select_relevant_patterns(query_state);

        // drain the negated list; retract_pattern may append to it while we drain
        while(not pending_retractions.empty()) {
            psc_ptr_type negated_p = pending_retractions.front();
            pending_retractions.pop_front();
            if(verbose) std::cout << "Retract Pattern: " << negated_p->get_pattern_index() << std::endl;
            retract_pattern(query_state, negated_p);
        }
    }
    return m_result_p;
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::initialize_query_state
//
// Seeds the result and the query state from the session.
//
// Each node produced by the session's nodes iterator gets a node score card
// marked "has node", with its asserted flag taken from the session, and is
// inserted in query_state.inferred_node_set.
// Each negated node of the session gets a node score card marked "has not"
// and asserted; it is not inserted in the inferred node set.
// A weight of 10 is added to every score card touched here.
/////////////////////////////////////////////////////////////////////////////////////////
void
PSearchInfer::initialize_query_state(infer_execution_state & query_state)
{
    // weight added to every score card created from the session's nodes
    const int seed_weight = 10;

    node_set_type & pending_nodes = query_state.inferred_node_set;

    // nodes the session currently has
    for(node_dual_set_const_iterator has_it = m_psession_p->get_nodes_iterator();
        not has_it.is_end();
        has_it.next())
    {
        node_index_type node_index = has_it.get_value_by_value();
        bool is_asserted = m_psession_p->is_asserted_has_node(node_index);

        nsc_ptr_type nsc_p = m_result_p->add_nsc(node_index);
        nsc_p->set_has_node(true);
        nsc_p->set_asserted_node(is_asserted);
        nsc_p->add_weight(seed_weight);

        pending_nodes.insert(node_index);
    }

    // nodes the session negates
    node_const_iterator_type neg_it = m_psession_p->get_negated_nodes_begin();
    node_const_iterator_type neg_end = m_psession_p->get_negated_nodes_end();
    while(neg_it != neg_end) {
        node_index_type node_index = *neg_it;

        nsc_ptr_type nsc_p = m_result_p->add_nsc(node_index);
        nsc_p->set_has_node(false);
        nsc_p->set_asserted_node(true);
        nsc_p->add_weight(seed_weight);

        ++neg_it;
    }
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::is_c1_match
//
// Returns true when the session has every c1 node of the pattern, false as
// soon as one of them is missing. A pattern without c1 nodes matches.
/////////////////////////////////////////////////////////////////////////////////////////
bool
PSearchInfer::is_c1_match(pattern_index_type pattern_index)
{
    node_const_iterator_type c1_it = pattern_index->get_c1_nodes_begin();
    node_const_iterator_type c1_end = pattern_index->get_c1_nodes_end();

    // stop scanning at the first c1 node the session does not have
    bool all_present = true;
    while(all_present and c1_it != c1_end) {
        if(not m_psession_p->has_node(*c1_it)) all_present = false;
        ++c1_it;
    }
    return all_present;
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::select_relevant_patterns
//
// Queues a new pattern score card for every pattern attached to a node of
// query_state.inferred_node_set, unless the pattern is already in the visited
// set. Null nodes and nodes flagged with is_skip_pattern_activation() are
// ignored. Queued patterns are added to the visited set, and the inferred node
// set is emptied before returning.
/////////////////////////////////////////////////////////////////////////////////////////
void
PSearchInfer::select_relevant_patterns(infer_execution_state & query_state)
{
    node_set_type & pending_nodes = query_state.inferred_node_set;
    psc_queue_type & candidates = query_state.psc_queue;
    pattern_set_type & visited = query_state.visited_pattern_set;

    node_const_iterator_type node_it = pending_nodes.begin();
    node_const_iterator_type node_end = pending_nodes.end();
    for(; node_it != node_end; ++node_it) {
        node_index_type node_index = *node_it;

        // nothing to activate for a null node or a node that opts out
        if(!node_index or node_index->is_skip_pattern_activation()) continue;

        pattern_const_iterator_type pat_it = node_index->get_patterns_begin();
        pattern_const_iterator_type pat_end = node_index->get_patterns_end();
        for(; pat_it != pat_end; ++pat_it) {
            pattern_index_type pattern_index = *pat_it;

            // a visited pattern already has a score card
            if(visited.find(pattern_index) != visited.end()) continue;

            psc_ptr_type psc_p(new PatternScoreCard(pattern_index));
            candidates.push(psc_p);
            psc_p->set_asserted(false);
            visited.insert(pattern_index);
        }
    }

    // the nodes have been consumed; emptying the set prevents creating
    // duplicate score cards on the next call
    pending_nodes.clear();
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::is_negated_by_has_not_nodes
//
// Returns true when at least one negated node of the session is a c2 node of
// the pattern, false otherwise.
/////////////////////////////////////////////////////////////////////////////////////////
bool
PSearchInfer::is_negated_by_has_not_nodes(pattern_index_type pattern_index)
{
    node_const_iterator_type neg_it = m_psession_p->get_negated_nodes_begin();
    node_const_iterator_type neg_end = m_psession_p->get_negated_nodes_end();

    // stop at the first negated node that the pattern lists as c2
    bool negated = false;
    while(not negated and neg_it != neg_end) {
        node_index_type node_index = *neg_it;
        if(pattern_index->is_c2_node(node_index)) negated = true;
        ++neg_it;
    }
    return negated;
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::is_negated_by_has_nodes
//
// Returns true when at least one node of the session negates the pattern of
// psc_p (pattern->is_negated_by_node). All session nodes are examined: each
// negating node is recorded with psc_p in query_state.filtered_pattern_map,
// which retract_pattern reads to re-queue the pattern when that node is removed.
/////////////////////////////////////////////////////////////////////////////////////////
bool
PSearchInfer::is_negated_by_has_nodes(infer_execution_state & query_state, psc_ptr_type psc_p)
{
    negate_pattern_map_type & filtered = query_state.filtered_pattern_map;
    pattern_index_type pattern_index = psc_p->get_pattern_index();

    bool found_negating_node = false;
    // no early exit: every negating node must be indexed
    for(node_dual_set_const_iterator has_it = m_psession_p->get_nodes_iterator();
        not has_it.is_end();
        has_it.next())
    {
        node_index_type node_index = has_it.get_value_by_value();
        if(not pattern_index->is_negated_by_node(node_index)) continue;

        filtered.insert(std::make_pair(node_index, psc_p));
        found_negating_node = true;
    }
    return found_negating_node;
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::infer_pattern
//
// Applies the pattern of psc_p to the result.
//
// For each c2 node of the pattern:
//   - the node is inserted in query_state.inferred_node_set,
//   - if the result already has a score card for it, frequency and weight are
//     increased (and the node counts as a match when it is asserted),
//   - otherwise a non-asserted "has node" score card is created, the node is
//     added to the session as an inferred node, and every pattern indexed in
//     negate_pattern_map under that node is appended to the negated list.
// The pattern score card is then added to the result and indexed in
// negate_pattern_map under each negated node of the pattern.
/////////////////////////////////////////////////////////////////////////////////////////
void
PSearchInfer::infer_pattern(infer_execution_state & query_state, psc_ptr_type psc_p)
{
    node_set_type & pending_nodes = query_state.inferred_node_set;
    negate_pattern_map_type & negate_index = query_state.negate_pattern_map;
    psc_list_type & pending_retractions = query_state.negated_psc_list;

    // read before the loop below adds inferred nodes to the session
    unsigned int nbr_asserted_nodes = m_psession_p->get_nbr_nodes();
    pattern_index_type pattern_index = psc_p->get_pattern_index();

    // create or update the score card of each c2 node
    unsigned int nbr_match_nodes = 0;
    node_const_iterator_type node_it = pattern_index->get_c2_nodes_begin();
    node_const_iterator_type node_end = pattern_index->get_c2_nodes_end();
    for(; node_it != node_end; ++node_it) {
        node_index_type node_index = *node_it;

        // always keep the node, even when already seen: one of its related
        // patterns may have been negated and must be considered again
        pending_nodes.insert(node_index);

        nsc_ptr_type nsc_p = m_result_p->get_nsc_ptr(node_index);
        if(nsc_p) {
            // known node: only update its score card
            if(nsc_p->is_asserted_node()) ++nbr_match_nodes;
            nsc_p->add_frequency(1);
            nsc_p->add_weight(pattern_index->get_weight());
            continue;
        }

        // node seen for the first time: create its score card
        nsc_p = m_result_p->add_nsc(node_index);
        nsc_p->add_frequency(1);
        nsc_p->set_has_node(true);
        nsc_p->set_asserted_node(false);
        nsc_p->add_weight(pattern_index->get_weight());

        // the session now has this node as an inferred node
        m_psession_p->add_inferred_node(node_index);

        // patterns negated by this node must be retracted
        negate_pattern_map_type::const_iterator negated_it;
        negate_pattern_map_type::const_iterator negated_end;
        boost::tie(negated_it, negated_end) = negate_index.equal_range(node_index);
        while(negated_it != negated_end) {
            pending_retractions.push_back(negated_it->second);
            ++negated_it;
        }
    }

    // finalize the pattern score card and publish it in the result
    psc_p->set_all_session_nodes_related(nbr_match_nodes == nbr_asserted_nodes);
    m_result_p->add_psc(psc_p);

    // index the pattern under each node that negates it
    node_it = pattern_index->get_negated_nodes_begin();
    node_end = pattern_index->get_negated_nodes_end();
    while(node_it != node_end) {
        negate_index.insert(std::make_pair(*node_it, psc_p));
        ++node_it;
    }
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::retract_pattern
//
// Undoes the contribution of the pattern of psc_p to the result.
//
// Nothing is done when the result holds no score card for the pattern: the
// negated list can name the same pattern several times (a pattern is queued
// once per dropped node it is attached to, and negate_pattern_map entries are
// never removed), and retracting twice would subtract frequency and weight
// from the c2 nodes twice.
//
// Otherwise, for each c2 node of the pattern that has a score card, frequency
// and weight are decreased. When the frequency reaches 0 on a non-asserted
// node:
//   - every pattern attached to the node that is in the result is appended to
//     the negated list,
//   - the node is removed from the session and its score card from the result,
//   - patterns recorded in filtered_pattern_map under that node are queued
//     again, unless they are already in the visited set (the map keeps stale
//     and duplicate entries; queuing a visited pattern would make it count
//     twice when inferred).
// Finally the pattern score card is removed from the result and the pattern
// from the visited set.
/////////////////////////////////////////////////////////////////////////////////////////
void
PSearchInfer::retract_pattern(infer_execution_state & query_state, psc_ptr_type psc_p)
{
    psc_list_type & negated_psc_list = query_state.negated_psc_list;
    pattern_set_type & visited_pattern_set = query_state.visited_pattern_set;
    psc_queue_type & psc_queue = query_state.psc_queue;
    negate_pattern_map_type & filtered_pattern_map = query_state.filtered_pattern_map;
    pattern_index_type pattern_index = psc_p->get_pattern_index();

    // already retracted (or never inferred): there is nothing to undo
    psc_ptr_type current_psc_p = m_result_p->get_psc_ptr(pattern_index);
    if(!current_psc_p) return;

    // for each c2 nodes of the pattern, update the node score card
    node_const_iterator_type nitor = pattern_index->get_c2_nodes_begin();
    node_const_iterator_type nend = pattern_index->get_c2_nodes_end();
    for(; nitor!=nend; ++nitor) {
        node_index_type node_index = *nitor;
        nsc_ptr_type nsc_p = m_result_p->get_nsc_ptr(node_index);
        if(!nsc_p) continue;

        nsc_p->sub_frequency(1);
        nsc_p->sub_weight(pattern_index->get_weight());

        // asserted nodes stay; inferred nodes stay while a pattern supports them
        if(nsc_p->get_frequency()!=0 or nsc_p->is_asserted_node()) continue;

        // retract all c1 patterns
        pattern_const_iterator_type pitor = node_index->get_patterns_begin();
        pattern_const_iterator_type pend = node_index->get_patterns_end();
        for (; pitor!=pend; ++pitor) {
            psc_ptr_type p = m_result_p->get_psc_ptr(*pitor);
            if(p) negated_psc_list.push_back(p);
        }

        // remove from psearch session inferred has node set
        m_psession_p->remove_node(node_index);

        // remove the nsc from result
        m_result_p->erase_nsc(node_index);

        // restore patterns that were filtered out by this node
        negate_pattern_map_type::const_iterator itor;
        negate_pattern_map_type::const_iterator end;
        boost::tie(itor, end) = filtered_pattern_map.equal_range(node_index);
        for(; itor!=end; ++itor) {
            psc_ptr_type p = itor->second;
            pattern_index_type restored_index = p->get_pattern_index();
            // a visited pattern is already queued or inferred; do not queue it twice
            if(visited_pattern_set.find(restored_index) == visited_pattern_set.end()) {
                psc_queue.push(p);
                visited_pattern_set.insert(restored_index);
            }
        }
    }

    // remove the psc from the result
    m_result_p->erase_psc(pattern_index);

    // remove from visited pattern since is retracted
    visited_pattern_set.erase(pattern_index);
}
/////////////////////////////////////////////////////////////////////////////////////////
// PSearchInfer::printQuery
//
/////////////////////////////////////////////////////////////////////////////////////////
void
PSearchInfer::printQuery()
{
std::cout << "\nInfer's asserted selected type are ";
node_const_iterator_type itor = m_psession_p->get_nodes_begin();
node_const_iterator_type end = m_psession_p->get_nodes_end();
for(; itor!=end; ++itor) std::cout << "'" << (*itor)->get_name() << "' ";
std::cout << "\nInfer's inferred selected nodes are: ";
itor = m_psession_p->get_inferred_nodes_begin();
end = m_psession_p->get_inferred_nodes_end();
for(; itor!=end; ++itor) std::cout << "'" << (*itor)->get_name() << "' ";
std::cout << "\nInfer's negated nodes are: ";
itor = m_psession_p->get_negated_nodes_begin();
end = m_psession_p->get_negated_nodes_end();
for(; itor!=end; ++itor) std::cout << "'" << (*itor)->get_name() << "' ";
std::cout << std::endl;
std::cout << std::endl;
};
}; /* psearch namespace */