summaryrefslogtreecommitdiff
path: root/python/openvino/runtime/common/demo_utils/include/utils/input_wrappers.hpp
blob: eff38a72771673884cedf2dfa712518ca233579d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <stdexcept>
#include <thread>
#include <vector>

#include <opencv2/opencv.hpp>

class InputChannel;

// Abstract producer of frames that one or more InputChannel objects subscribe to.
//
// The interface also exposes lock()/unlock() backed by a private mutex;
// InputChannel::read() holds that lock around its call to read().
class IInputSource {
public:
    // Fetches a frame into `mat` on behalf of `caller`.
    // The implementations in this file return false when no frame can be produced.
    virtual bool read(cv::Mat& mat, const std::shared_ptr<InputChannel>& caller) = 0;

    // Registers `inputChannel` as a consumer of this source.
    virtual void addSubscriber(const std::weak_ptr<InputChannel>& inputChannel) = 0;

    // Reports the frame size of this source.
    virtual cv::Size getSize() = 0;

    // Acquires the source mutex; blocks while another thread holds it.
    virtual void lock() { accessMutex.lock(); }

    // Releases the source mutex previously acquired through lock().
    virtual void unlock() { accessMutex.unlock(); }

    virtual ~IInputSource() = default;

private:
    std::mutex accessMutex;  // backs lock()/unlock()
};

class InputChannel: public std::enable_shared_from_this<InputChannel> { // note: public inheritance
public:
    InputChannel(const InputChannel&) = delete;
    InputChannel& operator=(const InputChannel&) = delete;
    static std::shared_ptr<InputChannel> create(const std::shared_ptr<IInputSource>& source) {
        auto tmp = std::shared_ptr<InputChannel>(new InputChannel(source));
        source->addSubscriber(tmp);
        return tmp;
    }
    bool read(cv::Mat& mat) {
        readQueueMutex.lock();
        if (readQueue.empty()) {
            readQueueMutex.unlock();
            source->lock();
            readQueueMutex.lock();
            if (readQueue.empty()) {
                bool res = source->read(mat, shared_from_this());
                readQueueMutex.unlock();
                source->unlock();
                return res;
            } else {
                source->unlock();
            }
        }
        mat = readQueue.front().clone();
        readQueue.pop();
        readQueueMutex.unlock();
        return true;
    }
    void push(const cv::Mat& mat) {
        readQueueMutex.lock();
        readQueue.push(mat);
        readQueueMutex.unlock();
    }
    cv::Size getSize() {
        return source->getSize();
    }

private:
    explicit InputChannel(const std::shared_ptr<IInputSource>& source): source{source} {}
    std::shared_ptr<IInputSource> source;
    std::queue<cv::Mat, std::list<cv::Mat>> readQueue;
    std::mutex readQueueMutex;
};

class VideoCaptureSource: public IInputSource {
public:
    VideoCaptureSource(const cv::VideoCapture& videoCapture, bool loop): videoCapture{videoCapture}, loop{loop},
        imSize{static_cast<int>(videoCapture.get(cv::CAP_PROP_FRAME_WIDTH)), static_cast<int>(videoCapture.get(cv::CAP_PROP_FRAME_HEIGHT))} {}
    bool read(cv::Mat& mat, const std::shared_ptr<InputChannel>& caller) override {
        if (!videoCapture.read(mat)) {
            if (loop) {
                videoCapture.set(cv::CAP_PROP_POS_FRAMES, 0);
                videoCapture.read(mat);
            } else {
                return false;
            }
        }
        if (1 != subscribedInputChannels.size()) {
            cv::Mat shared = mat.clone();
            for (const std::weak_ptr<InputChannel>& weakInputChannel : subscribedInputChannels) {
                try {
                    std::shared_ptr<InputChannel> sharedInputChannel = std::shared_ptr<InputChannel>(weakInputChannel);
                    if (caller != sharedInputChannel) {
                        sharedInputChannel->push(shared);
                    }
                } catch (const std::bad_weak_ptr&) {}
            }
        }
        return true;
    }
    void addSubscriber(const std::weak_ptr<InputChannel>& inputChannel) override {
        subscribedInputChannels.push_back(inputChannel);
    }
    cv::Size getSize() override {
        return imSize;
    }

private:
    std::vector<std::weak_ptr<InputChannel>> subscribedInputChannels;
    cv::VideoCapture videoCapture;
    bool loop;
    cv::Size imSize;
};

// IInputSource that serves a single still image.
//
// With `loop` set, every read() succeeds. Without it, each subscribed channel
// receives the image exactly once and is then dropped from the subscriber set.
class ImageSource: public IInputSource {
public:
    ImageSource(const cv::Mat& im, bool loop): im{im.clone()}, loop{loop} {}  // clone to avoid image changing

    // Stores a copy of the image in `mat`.
    //
    // Returns false only in non-loop mode when `caller` is not (or is no longer)
    // in the subscriber set. The frame handed out is a deep copy: a plain cv::Mat
    // assignment would share the pixel buffer, letting a consumer that edits the
    // frame in place alter the stored image seen by every other read.
    bool read(cv::Mat& mat, const std::shared_ptr<InputChannel>& caller) override {
        if (!loop) {
            const auto subscriberIt = subscribedInputChannels.find(caller);
            if (subscribedInputChannels.end() == subscriberIt) {
                return false;
            }
            // One-shot delivery: the channel will not be served again.
            subscribedInputChannels.erase(subscriberIt);
        }
        mat = im.clone();
        return true;
    }

    // Adds `inputChannel` to the subscriber set.
    // Throws std::invalid_argument when the set rejects the insertion
    // (an equivalent entry is already present).
    void addSubscriber(const std::weak_ptr<InputChannel>& inputChannel) override {
        if (!subscribedInputChannels.insert(inputChannel).second)
            throw std::invalid_argument("The insertion did not take place");
    }

    // Returns the size of the stored image.
    cv::Size getSize() override {
        return im.size();
    }

private:
    std::set<std::weak_ptr<InputChannel>, std::owner_less<std::weak_ptr<InputChannel>>> subscribedInputChannels;
    cv::Mat im;
    bool loop;
};