如何用tensorrt实现两个维度大小不同的张量点乘mul

请问有人知道怎么用tensorrt的api实现两个维度大小不同的张量点乘吗,比如a.shape=[64, 1, 1], b.shape=[64, 240, 320], 维度顺序是(C,H,W), a×b的shape是[64, 240, 320], 如何用tensorrt的api计算a×b,也就是pytorch里的torch.mul

在TensorRT中,可以使用plugin来自定义计算算法,实现两个维度大小不同的张量的点乘操作。具体流程如下:

  1. 实现一个自定义的TensorRT插件,可以继承IPluginV2接口。在实现该插件时,需要定义插件输入和输出的数据格式(data format),以及插件需要的配置。

  2. 在插件的实现中,可以直接获取输入和输出tensor的指针,然后利用循环遍历的方式计算点乘操作。

下面是一个实现的例子:

#include "NvInfer.h"
#include <cstdio>

using namespace nvinfer1;

class MultiplyPlugin : public IPluginV2
{
public:
    MultiplyPlugin() {}
    MultiplyPlugin(const void* data, size_t length)
    {
        const char *d = reinterpret_cast<const char*>(data), *a = d;
        mInputDims.nbDims = read<int>(d);
        for (int i = 0; i < mInputDims.nbDims; ++i)
            mInputDims.d[i] = read<int>(d);
        mOutputDims.nbDims = read<int>(d);
        for (int i = 0; i < mOutputDims.nbDims; ++i)
            mOutputDims.d[i] = read<int>(d);
        assert(d == a + length);
    }
    ~MultiplyPlugin() {}

    int getNbOutputs() const override { return 1; }
    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override { return mOutputDims; }
    bool supportsFormat(DataType type, PluginFormat format) const override { return (type == DataType::kFLOAT && format == PluginFormat::kLINEAR); }
    void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputs, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override
    {
        mDataType = type;
    }
    int initialize() override { return 0; }
    void terminate() override {}
    size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
    int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override
    {
        const float* input = reinterpret_cast<const float*>(inputs[0]);
        const float* weight = reinterpret_cast<const float*>(inputs[1]);
        float* output = reinterpret_cast<float*>(outputs[0]);
        const int inputSize = mInputDims.d[0];
        const int outputSize = mOutputDims.d[0] * mOutputDims.d[1] * mOutputDims.d[2];
        for (int n = 0; n < batchSize; ++n)
        {
            for (int c = 0; c < mOutputDims.d[0]; ++c)
            {
                for (int h = 0; h < mOutputDims.d[1]; ++h)
                {
                    for (int w = 0; w < mOutputDims.d[2]; ++w)
                    {
                        const int inputIndex = (n * inputSize) + c;
                        const int weightIndex = (n * outputSize) + (c * mOutputDims.d[1] * mOutputDims.d[2]) + (h * mOutputDims.d[2]) + w;
                        output[weightIndex] = input[inputIndex] * weight[weightIndex];
                    }
                }
            }
        }
        return 0;
    }
    size_t getSerializationSize() const override
    {
        return sizeof(int)*(1 + mInputDims.nbDims + mOutputDims.nbDims);
    }
    void serialize(void* buffer) const override
    {
        char *d = reinterpret_cast<char*>(buffer), *a = d;
        write(d, mInputDims.nbDims);
        for (int i = 0; i < mInputDims.nbDims; ++i)
            write(d, mInputDims.d[i]);
        write(d, mOutputDims.nbDims);
        for (int i = 0; i < mOutputDims.nbDims; ++i)
            write(d, mOutputDims.d[i]);
        assert(d == a + getSerializationSize());
    }
    void destroy() override { delete this; }
    const char* getPluginType() const override { return "MultiplyPlugin"; }
    const char* getPluginVersion() const override { return "1.0"; }
    void setPluginNamespace(const char* pluginNamespace) override { mNameSpace = pluginNamespace; }
    const char* getPluginNamespace() const override { return mNameSpace.c_str(); }

private:
    template<typename _T>
    static void write(char*& buffer, const _T& val)
    {
        *reinterpret_cast<_T*>(buffer) = val;
        buffer += sizeof(_T);
    }
    template<typename _T>
    static _T read(const char*& buffer)
    {
        _T val = *reinterpret_cast<const _T*>(buffer);
        buffer += sizeof(_T);
        return val;
    }

    DataType mDataType = DataType::kFLOAT;
    Dims mInputDims, mOutputDims;
    std::string mNameSpace;
};

class MultiplyPluginCreator : public IPluginCreator
{
public:
    MultiplyPluginCreator()
    {
        mPluginAttributes.emplace_back(PluginField("in_depth", nullptr, PluginFieldType::kINT32, 1));
        mPluginAttributes.emplace_back(PluginField("in_height", nullptr, PluginFieldType::kINT32, 1));
        mPluginAttributes.emplace_back(PluginField("in_width", nullptr, PluginFieldType::kINT32, 1));
        mPluginAttributes.emplace_back(PluginField("out_depth", nullptr, PluginFieldType::kINT32, 1));
        mPluginAttributes.emplace_back(PluginField("out_height", nullptr, PluginFieldType::kINT32, 1));
        mPluginAttributes.emplace_back(PluginField("out_width", nullptr, PluginFieldType::kINT32, 1));
    }
    ~MultiplyPluginCreator() {}

    const char* getPluginName() const override { return "MultiplyPlugin"; }
    const char* getPluginVersion() const override { return "1.0"; }

    const PluginFieldCollection* getFieldNames() override { return &mPluginAttributes; }
    IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override
    {
        const PluginField* fields = fc->fields;
        int inDepth = 1, inHeight = 1, inWidth = 1;
        int outDepth = 1, outHeight = 1, outWidth = 1;
        for (int i = 0; i < fc->nbFields; ++i)
        {
            if (!strcmp(fields[i].name, "in_depth"))
                inDepth = *(int*)fields[i].data;
            if (!strcmp(fields[i].name, "in_height"))
                inHeight = *(int*)fields[i].data;
            if (!strcmp(fields[i].name, "in_width"))
                inWidth = *(int*)fields[i].data;
            if (!strcmp(fields[i].name, "out_depth"))
                outDepth = *(int*)fields[i].data;
            if (!strcmp(fields[i].name, "out_height"))
                outHeight = *(int*)fields[i].data;
            if (!strcmp(fields[i].name, "out_width"))
                outWidth = *(int*)fields[i].data;
        }
        Dims inputDims = Dims3(inDepth, inHeight, inWidth);
        Dims outputDims = Dims3(outDepth, outHeight, outWidth);
        MultiplyPlugin* plugin = new MultiplyPlugin();
        plugin->setPluginNamespace(mNamespace.c_str());
        plugin->initialize();
        plugin->configureWithFormat(&inputDims, 1, &outputDims, 1, DataType::kFLOAT, PluginFormat::kLINEAR, 1);
        return plugin;
    }
    IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override
    {
        MultiplyPlugin* plugin = new MultiplyPlugin(serialData, serialLength);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    void setPluginNamespace(const char* libNamespace) override { mNamespace = libNamespace; }
    const char* getPluginNamespace() const override { return mNamespace.c_str(); }

private:
    std::string mNamespace;
    static PluginFieldCollection mPluginAttributes;
};

PluginFieldCollection MultiplyPluginCreator::mPluginAttributes;

extern "C" IPluginCreator& getPluginCreator()
{
    static MultiplyPluginCreator pluginCreator;
    return pluginCreator;
}

在上述代码中,自定义了一个名为MultiplyPlugin的插件,其中实现了自定义的点乘计算操作。该插件包含两个输入参数和一个输出参数,分别是输入张量、权重张量和输出张量。

接下来,可以在TensorRT中使用该自定义插件来实现两个维度大小不同的张量点乘。

// 创建Engine
IBuilder* builder = createInferBuilder(gLogger);
INetworkDefinition* network = builder->createNetworkV2(0U);
ITensor* a = network->addInput("a", DataType::kFLOAT, Dims3(1, 1, 64));
ITensor* b = network->addInput("b", DataType::kFLOAT, Dims3(320, 240, 64));
ITensor* ab[] = {a, b};
auto plugin = network->addPluginV2(ab, 2, createPlugin("MultiplyPlugin", pluginFactory));
ITensor* output = plugin->getOutput(0);
network->markOutput(*output);

在创建Engine时,需要调用createPlugin函数来实例化自定义插件,并将两个输入张量作为参数添加到插件中。创建Engine后,就可以像其他TensorRT网络一样使用了。

以上就是利用TensorRT的API实现两个大小不同张量点乘的步骤。

不知道你这个问题是否已经解决, 如果还没有解决的话:
  • 可以查看手册:pytorch mul_() (torch.Tensor method) 中的内容
  • 除此之外, 这篇博客: Pytorch Tensor的奇妙运算中的 2. torch.mul()函数 部分也许能够解决你的问题, 你可以仔细阅读以下内容或者直接跳转源博客中阅读:

    同型效果同点乘。

    当x与y同型时

    x.mul(y) == x*y

    对应点相乘

    不同型时,也会对较小者进行维度扩充。

  • 以下回答来自chatgpt:

    很抱歉,由于题目中并没有具体的问题描述,我无法提供任何解决方案。请您提供具体问题的描述,我会尽力为您解答。


如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^