summaryrefslogtreecommitdiff
path: root/python/openvino/runtime/common/demo_utils/include/utils/kuhn_munkres.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'python/openvino/runtime/common/demo_utils/include/utils/kuhn_munkres.hpp')
-rw-r--r--python/openvino/runtime/common/demo_utils/include/utils/kuhn_munkres.hpp57
1 files changed, 57 insertions, 0 deletions
diff --git a/python/openvino/runtime/common/demo_utils/include/utils/kuhn_munkres.hpp b/python/openvino/runtime/common/demo_utils/include/utils/kuhn_munkres.hpp
new file mode 100644
index 0000000..6e6ac51
--- /dev/null
+++ b/python/openvino/runtime/common/demo_utils/include/utils/kuhn_munkres.hpp
@@ -0,0 +1,57 @@
+// Copyright (C) 2018-2019 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "opencv2/core.hpp"
+
+#include <memory>
+#include <vector>
+
+
+///
+/// \brief The KuhnMunkres class
+///
+/// Solves the assignment problem.
+///
+class KuhnMunkres {
+public:
+ ///
+ /// \brief Initializes the class for assignment problem solving.
+ /// \param[in] greedy If a faster greedy matching algorithm should be used.
+ explicit KuhnMunkres(bool greedy = false);
+
+ ///
+ /// \brief Solves the assignment problem for given dissimilarity matrix.
+ /// It returns a vector that where each element is a column index for
+ /// corresponding row (e.g. result[0] stores optimal column index for very
+ /// first row in the dissimilarity matrix).
+ /// \param dissimilarity_matrix CV_32F dissimilarity matrix.
+ /// \return Optimal column index for each row. -1 means that there is no
+ /// column for row.
+ ///
+ std::vector<size_t> Solve(const cv::Mat &dissimilarity_matrix);
+
+private:
+ static constexpr int kStar = 1;
+ static constexpr int kPrime = 2;
+
+ cv::Mat dm_;
+ cv::Mat marked_;
+ std::vector<cv::Point> points_;
+
+ std::vector<int> is_row_visited_;
+ std::vector<int> is_col_visited_;
+
+ int n_;
+ bool greedy_;
+
+ void TrySimpleCase();
+ bool CheckIfOptimumIsFound();
+ cv::Point FindUncoveredMinValPos();
+ void UpdateDissimilarityMatrix(float val);
+ int FindInRow(int row, int what);
+ int FindInCol(int col, int what);
+ void Run();
+};