#include "NvInferPlugin.h"
#include "cuda_runtime_api.h"
#include <cublas_v2.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

using namespace nvinfer1;
class FlattenConcat : public IPluginV2 { public: FlattenConcat(int concatAxis, bool ignoreBatch) : mIgnoreBatch(ignoreBatch) , mConcatAxisID(concatAxis) { assert(mConcatAxisID == 1 || mConcatAxisID == 2 || mConcatAxisID == 3); } //clone constructor FlattenConcat(int concatAxis, bool ignoreBatch, int numInputs, int outputConcatAxis, int* inputConcatAxis) : mIgnoreBatch(ignoreBatch) , mConcatAxisID(concatAxis) , mOutputConcatAxis(outputConcatAxis) , mNumInputs(numInputs) { CHECK(cudaMallocHost((void**) &mInputConcatAxis, mNumInputs * sizeof(int))); for (int i = 0; i < mNumInputs; ++i) mInputConcatAxis[i] = inputConcatAxis[i]; }
FlattenConcat(const void* data, size_t length) { } ~FlattenConcat() { } int getNbOutputs() const noexcept override { return 1; } Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override { } int initialize() noexcept override { } void terminate() noexcept override { } size_t getWorkspaceSize(int) const noexcept override { return 0; } int enqueue(int batchSize, void const* const* inputs, void* const* outputs, void*, cudaStream_t stream) noexcept override { } size_t getSerializationSize() const noexcept override { } void serialize(void* buffer) const noexcept override { } void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept override { } bool supportsFormat(DataType type, PluginFormat format) const noexcept override { } const char* getPluginType() const noexcept override { return "FlattenConcat_TRT"; } const char* getPluginVersion() const noexcept override { return "1"; } void destroy() noexcept override { delete this; } IPluginV2* clone() const noexcept override { } void setPluginNamespace(const char* libNamespace) noexcept override { mNamespace = libNamespace; } const char* getPluginNamespace() const noexcept override { return mNamespace.c_str(); }
private: template <typename T> void write(char*& buffer, const T& val) const { } template <typename T> T read(const char*& buffer) { } size_t* mCopySize = nullptr; bool mIgnoreBatch{false}; int mConcatAxisID{0}, mOutputConcatAxis{0}, mNumInputs{0}; int* mInputConcatAxis = nullptr; nvinfer1::Dims mCHW; cublasHandle_t mCublas; std::string mNamespace; };
namespace
{
// Name and version reported by FlattenConcatPluginCreator::getPluginName() /
// getPluginVersion(); kept in sync with the literals returned by
// FlattenConcat::getPluginType() / getPluginVersion().
const char* FLATTENCONCAT_PLUGIN_NAME = "FlattenConcat_TRT";
const char* FLATTENCONCAT_PLUGIN_VERSION = "1";
} // namespace
class FlattenConcatPluginCreator : public IPluginCreator { public: FlattenConcatPluginCreator() { mPluginAttributes.emplace_back(PluginField("axis", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("ignoreBatch", nullptr, PluginFieldType::kINT32, 1)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); }
~FlattenConcatPluginCreator() {} const char* getPluginName() const noexcept override { return FLATTENCONCAT_PLUGIN_NAME; } const char* getPluginVersion() const noexcept override { return FLATTENCONCAT_PLUGIN_VERSION; } const PluginFieldCollection* getFieldNames() noexcept override { return &mFC; } IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override { } IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override { return new FlattenConcat(serialData, serialLength); } void setPluginNamespace(const char* libNamespace) noexcept override { mNamespace = libNamespace; } const char* getPluginNamespace() const noexcept override { return mNamespace.c_str(); }
private: static PluginFieldCollection mFC; bool mIgnoreBatch{false}; int mConcatAxisID; static std::vector<PluginField> mPluginAttributes; std::string mNamespace = ""; };
// Out-of-class definitions of the creator's static field metadata. They start
// empty; FlattenConcatPluginCreator's constructor fills them in.
PluginFieldCollection FlattenConcatPluginCreator::mFC = {};
std::vector<PluginField> FlattenConcatPluginCreator::mPluginAttributes{};

// Registers the creator with TensorRT's plugin registry. This must come after
// the static definitions above, because the creator's constructor uses them.
REGISTER_TENSORRT_PLUGIN(FlattenConcatPluginCreator);
|