mlpack  2.0.1
neighbor_search_rules.hpp
Go to the documentation of this file.
1 
15 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
16 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
17 
18 #include "ns_traversal_info.hpp"
19 
20 namespace mlpack {
21 namespace neighbor {
22 
23 template<typename SortPolicy, typename MetricType, typename TreeType>
25 {
26  public:
27  NeighborSearchRules(const typename TreeType::Mat& referenceSet,
28  const typename TreeType::Mat& querySet,
29  arma::Mat<size_t>& neighbors,
30  arma::mat& distances,
31  MetricType& metric,
32  const bool sameSet = false);
41  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
42 
51  double Score(const size_t queryIndex, TreeType& referenceNode);
52 
64  double Rescore(const size_t queryIndex,
65  TreeType& referenceNode,
66  const double oldScore) const;
67 
76  double Score(TreeType& queryNode, TreeType& referenceNode);
77 
89  double Rescore(TreeType& queryNode,
90  TreeType& referenceNode,
91  const double oldScore) const;
92 
94  size_t BaseCases() const { return baseCases; }
96  size_t& BaseCases() { return baseCases; }
97 
99  size_t Scores() const { return scores; }
101  size_t& Scores() { return scores; }
102 
105 
107  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
109  TraversalInfoType& TraversalInfo() { return traversalInfo; }
110 
111  protected:
113  const typename TreeType::Mat& referenceSet;
114 
116  const typename TreeType::Mat& querySet;
117 
119  arma::Mat<size_t>& neighbors;
120 
122  arma::mat& distances;
123 
125  MetricType& metric;
126 
128  bool sameSet;
129 
135  double lastBaseCase;
136 
138  size_t baseCases;
140  size_t scores;
141 
144  TraversalInfoType traversalInfo;
145 
149  double CalculateBound(TreeType& queryNode) const;
150 
160  void InsertNeighbor(const size_t queryIndex,
161  const size_t pos,
162  const size_t neighbor,
163  const double distance);
164 };
165 
166 } // namespace neighbor
167 } // namespace mlpack
168 
169 // Include implementation.
170 #include "neighbor_search_rules_impl.hpp"
171 
172 #endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
TraversalInfoType & TraversalInfo()
Modify the traversal info.
size_t lastQueryIndex
The last query point BaseCase() was called with.
Linear algebra utility functions, generally performed on matrices or vectors.
arma::mat & distances
The matrix the resultant neighbor distances should be stored in.
TraversalInfoType traversalInfo
Traversal info for the parent combination; this is updated by the traversal before each call to Score().
size_t BaseCases() const
Get the number of base cases that have been performed.
NeighborSearchRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, arma::Mat< size_t > &neighbors, arma::mat &distances, MetricType &metric, const bool sameSet=false)
const TreeType::Mat & referenceSet
The reference set.
void InsertNeighbor(const size_t queryIndex, const size_t pos, const size_t neighbor, const double distance)
Insert a point into the neighbors and distances matrices; this is a helper function.
NeighborSearchTraversalInfo< TreeType > TraversalInfoType
Convenience typedef.
Traversal information for NeighborSearch.
size_t & Scores()
Modify the number of scores that have been performed.
double lastBaseCase
The last base case result.
size_t lastReferenceIndex
The last reference point BaseCase() was called with.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
size_t Scores() const
Get the number of scores that have been performed.
arma::Mat< size_t > & neighbors
The matrix the resultant neighbor indices should be stored in.
const TraversalInfoType & TraversalInfo() const
Get the traversal info.
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
size_t & BaseCases()
Modify the number of base cases that have been performed.
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
double CalculateBound(TreeType &queryNode) const
Recalculate the bound for a given query node.
MetricType & metric
The instantiated metric.
size_t scores
The number of scores that have been performed.
size_t baseCases
The number of base cases that have been performed.
const TreeType::Mat & querySet
The query set.
bool sameSet
Denotes whether or not the reference and query sets are the same.