#ifndef PSEARCHQUERY_H_
#define PSEARCHQUERY_H_
#include "psearch_db.h"
namespace psearch {
// Forward declarations: these types are referenced by pointer or by shared
// handle in this header, so only their names are needed at this point.
class PSearchSession;
class PSearchInfer;
class PSResult;
// Shared-ownership handle to a PSResult.
typedef std::tr1::shared_ptr<PSResult> psresult_ptr_type;
class NodeScoreCard;
// Shared-ownership handle to a NodeScoreCard.
typedef std::tr1::shared_ptr<NodeScoreCard> nsc_ptr_type;
class PatternScoreCard;
// Shared-ownership handle to a PatternScoreCard.
typedef std::tr1::shared_ptr<PatternScoreCard> psc_ptr_type;
/////////////////////////////////////////////////////////////////////////////////////////
// class NodeScoreCard
//
/////////////////////////////////////////////////////////////////////////////////////////
// Score card attached to a single node: a frequency counter, a cumulative
// weight and a bit-flag word, all tied to the wrapped node index.
class NodeScoreCard
{
public:
  // Starts with zero frequency, zero weight and the flag word set to none_e.
  NodeScoreCard(node_index_type node_idx)
    : m_node_index(node_idx),
      m_freq(0),
      m_weight(0),
      m_flags(none_e)
  {}

  // Index of the node this card describes.
  node_index_type get_node_index() const { return m_node_index; }

  // The next three accessors forward to the wrapped node index.
  index_type get_rdf_index() const { return m_node_index->get_rdf_index(); }
  std::string get_index_name() const { return m_node_index->get_index_name(); }
  std::string const& get_name() const { return m_node_index->get_name(); }

  // Frequency counter.
  unsigned int get_frequency() const { return m_freq; }
  void add_frequency(unsigned int f=1) { m_freq += f; }
  // Saturating decrement: the counter stops at 0 instead of wrapping around.
  void sub_frequency(unsigned int f=1) { m_freq = (f > m_freq) ? 0 : m_freq - f; }

  // Cumulative weight.
  unsigned int get_weight() const { return m_weight; }
  void add_weight(unsigned int w=1) { m_weight += w; }
  // Saturating decrement, same policy as sub_frequency().
  void sub_weight(unsigned int w=1) { m_weight = (w > m_weight) ? 0 : m_weight - w; }

  // has_e bit of the flag word.
  bool is_has_node() const { return (m_flags & has_e) != 0; }
  void set_has_node(bool b)
  {
    if(b) m_flags |= has_e;
    else m_flags &= ~has_e;
  }

  // asserted_e bit of the flag word.
  bool is_asserted_node() const { return (m_flags & asserted_e) != 0; }
  void set_asserted_node(bool b)
  {
    if(b) m_flags |= asserted_e;
    else m_flags &= ~asserted_e;
  }

private:
  // Stream inserters are friends so they can read the counters directly.
  friend std::ostream& operator<<(std::ostream& out, NodeScoreCard const& nsc);
  friend std::ostream& operator<<(std::ostream& out, NodeScoreCard const* nsc_p);
  friend std::ostream& operator<<(std::ostream& out, nsc_ptr_type nsc_p);

  node_index_type m_node_index;
  unsigned int m_freq;   // how often the node is relevant
  unsigned int m_weight; // cumulative pattern weight from all related patterns
  unsigned int m_flags;  // bit set combining has_e / asserted_e
};
// Writes the card as a single self-closing <NSC .../> element carrying the
// index name, frequency, has/has_not state, asserted/inferred state and weight.
inline std::ostream& operator<<(std::ostream& out, NodeScoreCard const& nsc)
{
  char const* const presence = nsc.is_has_node() ? "has" : "has_not";
  char const* const origin = nsc.is_asserted_node() ? "asserted" : "inferred";

  out << "<NSC index=\"" << nsc.get_index_name();
  out << "\" f=\"" << nsc.m_freq;
  out << "\" p=\"" << presence;
  out << "\" a=\"" << origin;
  out << "\" w=\"" << nsc.m_weight << "\"/>";
  return out;
}
inline std::ostream& operator<<(std::ostream& out, NodeScoreCard const* nsc_p)
{
out << *nsc_p;
return out;
};
inline std::ostream& operator<<(std::ostream& out, nsc_ptr_type nsc_p)
{
out << &*nsc_p;
return out;
};
/////////////////////////////////////////////////////////////////////////////////////////
// class PatternScoreCard
//
/////////////////////////////////////////////////////////////////////////////////////////
// Score card attached to a single pattern: the wrapped pattern index plus a
// bit-flag word (all_related_e, asserted_e, c1_match_e).
class PatternScoreCard
{
public:
  // Starts with the flag word set to none_e.
  PatternScoreCard(pattern_index_type pattern_idx)
    : m_pattern_index(pattern_idx),
      m_flags(none_e)
  {}

  // Index of the pattern this card describes.
  pattern_index_type get_pattern_index() const { return m_pattern_index; }

  // The next three accessors forward to the wrapped pattern index.
  index_type get_rdf_index() const { return m_pattern_index->get_rdf_index(); }
  std::string get_index_name() const { return m_pattern_index->get_index_name(); }
  std::string const& get_name() const { return m_pattern_index->get_name(); }

  // all_related_e bit of the flag word.
  bool is_all_session_nodes_related() const { return (m_flags & all_related_e) != 0; }
  void set_all_session_nodes_related(bool b)
  {
    if(b) m_flags |= all_related_e;
    else m_flags &= ~all_related_e;
  }

  // asserted_e bit of the flag word.
  bool is_asserted() const { return (m_flags & asserted_e) != 0; }
  void set_asserted(bool b)
  {
    if(b) m_flags |= asserted_e;
    else m_flags &= ~asserted_e;
  }

  // c1_match_e bit of the flag word.
  bool is_c1_match() const { return (m_flags & c1_match_e) != 0; }
  void set_c1_match(bool b)
  {
    if(b) m_flags |= c1_match_e;
    else m_flags &= ~c1_match_e;
  }

private:
  friend std::ostream& operator<<(std::ostream& out, PatternScoreCard const& psc);
  friend std::ostream& operator<<(std::ostream& out, PatternScoreCard const* psc_p);
  friend std::ostream& operator<<(std::ostream& out, psc_ptr_type psc_p);

  pattern_index_type m_pattern_index;
  unsigned int m_flags; // bit set combining all_related_e / asserted_e / c1_match_e
};
// Writes the card as a single self-closing <PSC .../> element carrying the
// index name, asserted/inferred state, c1 match state and relation state.
inline std::ostream& operator<<(std::ostream& out, PatternScoreCard const& psc)
{
  char const* const origin = psc.is_asserted() ? "asserted" : "inferred";
  char const* const c1 = psc.is_c1_match() ? "true" : "false";
  char const* const relation = psc.is_all_session_nodes_related() ? "all_related" : "some_related";

  out << "<PSC index=\"" << psc.get_index_name();
  out << "\" a=\"" << origin;
  out << "\" c1=\"" << c1;
  out << "\" r=\"" << relation << "\"/>";
  return out;
}
// Pointer flavour of the PatternScoreCard inserter.
//
// Streams the card addressed by psc_p by delegating to the by-reference
// overload, which already emits the complete <PSC .../> element. Nothing is
// written before delegating: an extra "<PSC index=\"" prefix here would
// duplicate the opening of the element and yield malformed output.
//
// psc_p is dereferenced unconditionally, so it must not be null.
// Returns the stream to allow chaining.
inline std::ostream& operator<<(std::ostream& out, PatternScoreCard const* psc_p)
{
  out << *psc_p;
  return out;
}
inline std::ostream& operator<<(std::ostream& out, psc_ptr_type psc_p)
{
out << &*psc_p;
return out;
};
// Node score cards keyed by node index.
typedef std::tr1::unordered_map<node_index_type, nsc_ptr_type> nsc_map_type;
// Ordered list of pattern score card handles.
typedef std::list<psc_ptr_type> psc_list_type;
// Pattern score cards keyed by pattern index.
typedef std::tr1::unordered_map<pattern_index_type, psc_ptr_type> psc_map_type;
// Iterator wrappers over the two maps that follow the is_end() / get_value() / next() protocol.
typedef rdf::top_map_iterator<nsc_map_type::const_iterator, nsc_ptr_type> nsc_map_const_iterator;
typedef rdf::top_map_iterator<psc_map_type::const_iterator, psc_ptr_type> psc_map_const_iterator;
/////////////////////////////////////////////////////////////////////////////////////////
// class query_params
//
/////////////////////////////////////////////////////////////////////////////////////////
// Parameters of a query: a set of category indexes resolved through the
// database handle, plus two boolean options.
class query_params
{
public:
  // Starts with no category selected and both options switched off.
  query_params(psearch_db_ptr_type db_p)
    : m_db_p(db_p),
      m_categories(),
      m_keep_session_nodes(false),
      m_select_relevant_patterns_only(false)
  {}

  ~query_params() {}

  // "keep session nodes" option.
  bool is_keep_session_nodes() const { return m_keep_session_nodes; }
  void set_keep_session_nodes(bool b) { m_keep_session_nodes = b; }

  // "select relevant patterns only" option.
  bool select_relevant_patterns_only() const { return m_select_relevant_patterns_only; }
  void set_select_relevant_patterns_only(bool b) { m_select_relevant_patterns_only = b; }

  // Adds the category called `name` to the selection. Names the database
  // does not know are ignored silently.
  void add_category(std::string const& name)
  {
    if(m_db_p->has_category(name)) {
      m_categories.insert(m_db_p->get_category_index(name));
    }
  }

  // Empties the category selection; the two boolean options are left as they are.
  void clear() { m_categories.clear(); }

  // True when category_index is part of the selection.
  bool has_category(category_index_type category_index) const
  {
    return !(m_categories.find(category_index) == m_categories.end());
  }

  // Raw begin/end iterators over the selected categories.
  category_const_iterator_type get_categories_begin() const { return m_categories.begin(); }
  category_const_iterator_type get_categories_end() const { return m_categories.end(); }

  // Same range wrapped in a category_set_const_iterator.
  category_set_const_iterator get_categories_iterator() const
  {
    category_const_iterator_type const first = get_categories_begin();
    category_const_iterator_type const last = get_categories_end();
    return category_set_const_iterator(first, last);
  }

  // Number of selected categories.
  size_t get_nbr_categories() const { return m_categories.size(); }

private:
  psearch_db_ptr_type m_db_p;           // used to resolve category names
  category_set_type m_categories;       // selected category indexes
  bool m_keep_session_nodes;
  bool m_select_relevant_patterns_only;
};
/////////////////////////////////////////////////////////////////////////////////////////
// class PSResult
//
/////////////////////////////////////////////////////////////////////////////////////////
class PSResult
{
public:
PSResult(PSearchDB const* db_p, PSearchSession * psession_p):
m_db_p(db_p),
m_psession_p(psession_p)
{};
inline unsigned int
get_nbr_nsc()const
{
return m_nsc_map.size();
};
inline unsigned int
get_nbr_psc()const
{
return m_psc_map.size();
};
inline nsc_map_const_iterator
get_nsc_ptr_iterator()const
{
return nsc_map_const_iterator(m_nsc_map.begin(), m_nsc_map.end());
};
inline psc_map_const_iterator
get_psc_ptr_iterator()const
{
return psc_map_const_iterator(m_psc_map.begin(), m_psc_map.end());
};
inline nsc_ptr_type
get_nsc_ptr(node_index_type node_index)const
{
nsc_map_type::const_iterator itor = m_nsc_map.find(node_index);
if(itor == m_nsc_map.end()) return nsc_ptr_type();
return itor->second;
};
inline psc_ptr_type
get_psc_ptr(pattern_index_type pattern_index)const
{
psc_map_type::const_iterator itor = m_psc_map.find(pattern_index);
if(itor == m_psc_map.end()) return psc_ptr_type();
return itor->second;
};
void
apply_filter(query_params const& params);
protected:
inline nsc_ptr_type
add_nsc(node_index_type node_index)
{
nsc_ptr_type nsc_p(new NodeScoreCard(node_index));
m_nsc_map.insert(std::make_pair(node_index, nsc_p));
return nsc_p;
};
inline void
erase_nsc(node_index_type node_index)
{
nsc_map_type::iterator pos = m_nsc_map.find(node_index);
if(pos != m_nsc_map.end()) m_nsc_map.erase(pos);
};
inline void
add_psc(psc_ptr_type const& psc_p)
{
m_psc_map.insert(std::make_pair(psc_p->get_pattern_index(), psc_p));
};
inline void
erase_psc(psc_map_type::iterator pos)
{
m_psc_map.erase(pos);
};
inline void
erase_psc(pattern_index_type pattern_index)
{
psc_map_type::iterator pos = m_psc_map.find(pattern_index);
if(pos != m_psc_map.end()) m_psc_map.erase(pos);
};
inline bool
has_psc(pattern_index_type pattern_index)const
{
if(!pattern_index) return false;
return m_psc_map.find(pattern_index) != m_psc_map.end();
};
inline psc_map_type::iterator
psc_iterator_begin()
{
return m_psc_map.begin();
};
inline psc_map_type::iterator
psc_iterator_end()
{
return m_psc_map.end();
};
private:
friend class PSearchQuery;
friend class PSearchInfer;
friend bool isResultValid(std::vector<std::string> &errors, PSearchSession * m_psession_p, psresult_ptr_type m_result_p);
PSearchDB const* m_db_p;
PSearchSession * m_psession_p;
nsc_map_type m_nsc_map;
psc_map_type m_psc_map;
};
/////////////////////////////////////////////////////////////////////////////////////////
// struct query_execution_state
//
// Working state passed by reference to the protected PSearchQuery steps declared in this header.
/////////////////////////////////////////////////////////////////////////////////////////
struct query_execution_state
{
// List of pattern score card handles (psc_list_type is a std::list of psc_ptr_type).
psc_list_type psc_list;
};
/////////////////////////////////////////////////////////////////////////////////////////
// class PSearchQuery
//
/////////////////////////////////////////////////////////////////////////////////////////
// Runs queries against a database / session pair and accumulates the outcome
// in a PSResult that is created once, at construction time.
class PSearchQuery
{
protected:
  // Individual query steps; all of them are defined out of line.
  bool is_c1_match(pattern_index_type patern_index);
  void select_relevant_patterns(query_execution_state & query_state);
  void select_all_patterns(query_execution_state & query_state);
  void prune_patterns_from_has_not_nodes(query_execution_state & query_state);
  void prune_patterns_from_has_nodes(query_execution_state & query_state);
  void collect_nodes_from_patterns(query_params const& params, query_execution_state & query_state);
  void infer_nsc_into_graph(query_execution_state & query_state);

public:
  // Stores the database / session pointers and max_loop, and allocates the
  // PSResult that get_result() hands out.
  PSearchQuery(PSearchDB const* db_p, PSearchSession * psession_p, unsigned int max_loop=1000)
    : m_db_p(db_p),
      m_psession_p(psession_p),
      m_result_p(new PSResult(db_p, psession_p)),
      m_max_loop(max_loop)
  {}

  ~PSearchQuery() {}

  // Defined out of line.
  void printQuery();

  // Query entry points (defined out of line): one driven by a full
  // query_params object, one by a single category index.
  psresult_ptr_type query(query_params const& params, bool verbose);
  psresult_ptr_type query(category_index_type category_index, bool verbose);

  // Shared handle to the result object owned by this query.
  psresult_ptr_type get_result() const { return m_result_p; }

private:
  PSearchDB const* m_db_p;
  PSearchSession * m_psession_p;
  psresult_ptr_type m_result_p;
  unsigned int m_max_loop;
};
/////////////////////////////////////////////////////////////////////////////////////////
// whyNode
//
// Populate the explanation list with the reasons why the node was returned in the
// node score card list.
// Declaration only; the definition is not part of this header.
/////////////////////////////////////////////////////////////////////////////////////////
void
whyNode(std::vector<std::string> &explanation, node_index_type node_index, PSearchSession * m_psession_p, psresult_ptr_type m_result_p);
/////////////////////////////////////////////////////////////////////////////////////////
// whyPattern
//
// Populate the explanation list with the reasons why the pattern was returned in the
// pattern score card list.
// Declaration only; the definition is not part of this header.
/////////////////////////////////////////////////////////////////////////////////////////
void
whyPattern(std::vector<std::string> &explanation, pattern_index_type pindex, PSearchSession * m_psession_p, psresult_ptr_type m_result_p);
/////////////////////////////////////////////////////////////////////////////////////////
// isResultValid
//
// Checks whether the obtained result is valid. This is for testing purposes.
// Returns true if the result is validated.
// NOTE(review): the original comment also said information is printed "if verbose is
// true", but this signature has no verbose parameter; presumably problems are reported
// through `errors` instead - confirm against the definition.
// PSResult declares this function as a friend.
/////////////////////////////////////////////////////////////////////////////////////////
bool
isResultValid(std::vector<std::string> &errors, PSearchSession * m_psession_p, psresult_ptr_type m_result_p);
/*
 * Compute the result using a naive implementation.
 * Used for comparing the performance of the implementation and to validate the result.
 * NOTE(review): xnsc and xpsc are taken by non-const reference and are presumably the
 * outputs (node and pattern score card maps) - confirm against the definition.
 */
void computeResults(PSearchDB const* db_p, PSearchSession const* psession_p, nsc_map_type & xnsc, psc_map_type & xpsc);
}; /* psearch namespace */
#endif /*PSEARCHQUERY_H_*/