// This file is part of OpenCV project. // It is subject to the license terms in the LICENSE file found in the top-level directory // of this distribution and at http://opencv.org/license.html. // // Copyright (C) 2019 Intel Corporation // #ifndef OPENCV_GAPI_GPLAIDMLKERNEL_HPP #define OPENCV_GAPI_GPLAIDMLKERNEL_HPP #include #include namespace plaidml { namespace edsl { class Tensor; } // namespace edsl } // namespace plaidml namespace cv { namespace gapi { namespace plaidml { GAPI_EXPORTS cv::gapi::GBackend backend(); } // namespace plaidml } // namespace gapi struct GPlaidMLContext { // Generic accessor API template const T& inArg(int input) { return m_args.at(input).get(); } // Syntax sugar const plaidml::edsl::Tensor& inTensor(int input) { return inArg(input); } plaidml::edsl::Tensor& outTensor(int output) { return *(m_results.at(output).get()); } std::vector m_args; std::unordered_map m_results; }; class GAPI_EXPORTS GPlaidMLKernel { public: using F = std::function; GPlaidMLKernel() = default; explicit GPlaidMLKernel(const F& f) : m_f(f) {} void apply(GPlaidMLContext &ctx) const { GAPI_Assert(m_f); m_f(ctx); } protected: F m_f; }; namespace detail { template struct plaidml_get_in; template<> struct plaidml_get_in { static const plaidml::edsl::Tensor& get(GPlaidMLContext& ctx, int idx) { return ctx.inTensor(idx); } }; template struct plaidml_get_in { static T get(GPlaidMLContext &ctx, int idx) { return ctx.inArg(idx); } }; template struct plaidml_get_out; template<> struct plaidml_get_out { static plaidml::edsl::Tensor& get(GPlaidMLContext& ctx, int idx) { return ctx.outTensor(idx); } }; template struct PlaidMLCallHelper; template struct PlaidMLCallHelper, std::tuple > { template static void call_impl(GPlaidMLContext &ctx, detail::Seq, detail::Seq) { Impl::run(plaidml_get_in::get(ctx, IIs)..., plaidml_get_out::get(ctx, OIs)...); } static void call(GPlaidMLContext& ctx) { call_impl(ctx, typename detail::MkSeq::type(), typename detail::MkSeq::type()); } }; } // namespace detail template class GPlaidMLKernelImpl: public cv::detail::PlaidMLCallHelper, public cv::detail::KernelTag { using P = detail::PlaidMLCallHelper; public: using API = K; static cv::gapi::GBackend backend() { return cv::gapi::plaidml::backend(); } static cv::GPlaidMLKernel kernel() { return GPlaidMLKernel(&P::call); } }; #define GAPI_PLAIDML_KERNEL(Name, API) struct Name: public cv::GPlaidMLKernelImpl } // namespace cv #endif // OPENCV_GAPI_GPLAIDMLKERNEL_HPP