请问有人知道怎么用tensorrt的api实现两个维度大小不同的张量点乘吗,比如a.shape=[64, 1, 1], b.shape=[64, 240, 320], 维度顺序是(C,H,W), a×b的shape是[64, 240, 320], 如何用tensorrt的api计算a×b,也就是pytorch里的torch.mul
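在TensorRT中,其实不需要写plugin:IElementWiseLayer本身支持广播。只要两个输入的维度数相同,并且每个维度要么相等、要么其中一个为1,就会沿大小为1的维度自动扩展。因此对于a.shape=[64, 1, 1]、b.shape=[64, 240, 320],直接调用network->addElementWise(*a, *b, ElementWiseOperation::kPROD)即可得到形状为[64, 240, 320]的结果,效果等价于torch.mul。如果确实需要自定义计算逻辑,也可以用plugin实现两个维度大小不同的张量的逐元素相乘,具体流程如下: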
实现一个自定义的TensorRT插件,可以继承IPluginV2接口。在实现该插件时,需要定义插件输入和输出的数据格式(data format),以及插件需要的配置。
在插件的enqueue实现中,可以获取输入和输出tensor的指针。注意:这些指针指向GPU显存,不能在主机端直接用循环遍历访问。需要编写CUDA kernel来完成逐元素相乘;也可以先用cudaMemcpyAsync把数据拷贝到主机内存,计算后再拷回,但这种做法仅适合演示,性能很差。
下面是一个实现的例子:
#include <cassert>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include <cuda_runtime_api.h>
#include "NvInfer.h"
using namespace nvinfer1;
class MultiplyPlugin : public IPluginV2
{
public:
MultiplyPlugin() {}
MultiplyPlugin(const void* data, size_t length)
{
const char *d = reinterpret_cast<const char*>(data), *a = d;
mInputDims.nbDims = read<int>(d);
for (int i = 0; i < mInputDims.nbDims; ++i)
mInputDims.d[i] = read<int>(d);
mOutputDims.nbDims = read<int>(d);
for (int i = 0; i < mOutputDims.nbDims; ++i)
mOutputDims.d[i] = read<int>(d);
assert(d == a + length);
}
~MultiplyPlugin() {}
int getNbOutputs() const override { return 1; }
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override { return mOutputDims; }
bool supportsFormat(DataType type, PluginFormat format) const override { return (type == DataType::kFLOAT && format == PluginFormat::kLINEAR); }
void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputs, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override
{
mDataType = type;
}
int initialize() override { return 0; }
void terminate() override {}
size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override
{
const float* input = reinterpret_cast<const float*>(inputs[0]);
const float* weight = reinterpret_cast<const float*>(inputs[1]);
float* output = reinterpret_cast<float*>(outputs[0]);
const int inputSize = mInputDims.d[0];
const int outputSize = mOutputDims.d[0] * mOutputDims.d[1] * mOutputDims.d[2];
for (int n = 0; n < batchSize; ++n)
{
for (int c = 0; c < mOutputDims.d[0]; ++c)
{
for (int h = 0; h < mOutputDims.d[1]; ++h)
{
for (int w = 0; w < mOutputDims.d[2]; ++w)
{
const int inputIndex = (n * inputSize) + c;
const int weightIndex = (n * outputSize) + (c * mOutputDims.d[1] * mOutputDims.d[2]) + (h * mOutputDims.d[2]) + w;
output[weightIndex] = input[inputIndex] * weight[weightIndex];
}
}
}
}
return 0;
}
size_t getSerializationSize() const override
{
return sizeof(int)*(1 + mInputDims.nbDims + mOutputDims.nbDims);
}
void serialize(void* buffer) const override
{
char *d = reinterpret_cast<char*>(buffer), *a = d;
write(d, mInputDims.nbDims);
for (int i = 0; i < mInputDims.nbDims; ++i)
write(d, mInputDims.d[i]);
write(d, mOutputDims.nbDims);
for (int i = 0; i < mOutputDims.nbDims; ++i)
write(d, mOutputDims.d[i]);
assert(d == a + getSerializationSize());
}
void destroy() override { delete this; }
const char* getPluginType() const override { return "MultiplyPlugin"; }
const char* getPluginVersion() const override { return "1.0"; }
void setPluginNamespace(const char* pluginNamespace) override { mNameSpace = pluginNamespace; }
const char* getPluginNamespace() const override { return mNameSpace.c_str(); }
private:
template<typename _T>
static void write(char*& buffer, const _T& val)
{
*reinterpret_cast<_T*>(buffer) = val;
buffer += sizeof(_T);
}
template<typename _T>
static _T read(const char*& buffer)
{
_T val = *reinterpret_cast<const _T*>(buffer);
buffer += sizeof(_T);
return val;
}
DataType mDataType = DataType::kFLOAT;
Dims mInputDims, mOutputDims;
std::string mNameSpace;
};
class MultiplyPluginCreator : public IPluginCreator
{
public:
MultiplyPluginCreator()
{
mPluginAttributes.emplace_back(PluginField("in_depth", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("in_height", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("in_width", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("out_depth", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("out_height", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("out_width", nullptr, PluginFieldType::kINT32, 1));
}
~MultiplyPluginCreator() {}
const char* getPluginName() const override { return "MultiplyPlugin"; }
const char* getPluginVersion() const override { return "1.0"; }
const PluginFieldCollection* getFieldNames() override { return &mPluginAttributes; }
IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override
{
const PluginField* fields = fc->fields;
int inDepth = 1, inHeight = 1, inWidth = 1;
int outDepth = 1, outHeight = 1, outWidth = 1;
for (int i = 0; i < fc->nbFields; ++i)
{
if (!strcmp(fields[i].name, "in_depth"))
inDepth = *(int*)fields[i].data;
if (!strcmp(fields[i].name, "in_height"))
inHeight = *(int*)fields[i].data;
if (!strcmp(fields[i].name, "in_width"))
inWidth = *(int*)fields[i].data;
if (!strcmp(fields[i].name, "out_depth"))
outDepth = *(int*)fields[i].data;
if (!strcmp(fields[i].name, "out_height"))
outHeight = *(int*)fields[i].data;
if (!strcmp(fields[i].name, "out_width"))
outWidth = *(int*)fields[i].data;
}
Dims inputDims = Dims3(inDepth, inHeight, inWidth);
Dims outputDims = Dims3(outDepth, outHeight, outWidth);
MultiplyPlugin* plugin = new MultiplyPlugin();
plugin->setPluginNamespace(mNamespace.c_str());
plugin->initialize();
plugin->configureWithFormat(&inputDims, 1, &outputDims, 1, DataType::kFLOAT, PluginFormat::kLINEAR, 1);
return plugin;
}
IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override
{
MultiplyPlugin* plugin = new MultiplyPlugin(serialData, serialLength);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
void setPluginNamespace(const char* libNamespace) override { mNamespace = libNamespace; }
const char* getPluginNamespace() const override { return mNamespace.c_str(); }
private:
std::string mNamespace;
static PluginFieldCollection mPluginAttributes;
};
PluginFieldCollection MultiplyPluginCreator::mPluginAttributes;
extern "C" IPluginCreator& getPluginCreator()
{
static MultiplyPluginCreator pluginCreator;
return pluginCreator;
}
在上述代码中,自定义了一个名为MultiplyPlugin的插件,实现按通道广播的逐元素相乘。该插件有两个输入和一个输出:输入张量a(形状[C, 1, 1],每个通道一个系数)、权重张量b(形状[C, H, W]),以及形状与b相同的输出张量。
接下来,可以在TensorRT中使用该自定义插件来实现两个维度大小不同的张量点乘。
// 创建Engine
IBuilder* builder = createInferBuilder(gLogger);
INetworkDefinition* network = builder->createNetworkV2(0U);
ITensor* a = network->addInput("a", DataType::kFLOAT, Dims3(1, 1, 64));
ITensor* b = network->addInput("b", DataType::kFLOAT, Dims3(320, 240, 64));
ITensor* ab[] = {a, b};
auto plugin = network->addPluginV2(ab, 2, createPlugin("MultiplyPlugin", pluginFactory));
ITensor* output = plugin->getOutput(0);
network->markOutput(*output);
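在创建Engine时,需要通过插件creator的createPlugin方法实例化自定义插件,再调用addPluginV2,把两个输入张量和插件对象一起添加到网络中。若要在运行时反序列化包含该插件的engine,还需要把creator注册到TensorRT的插件注册表(例如使用REGISTER_TENSORRT_PLUGIN宏)。创建Engine后,就可以像其他TensorRT网络一样使用了。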
以上就是利用TensorRT的API实现两个大小不同张量点乘的步骤。
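不知道你这个问题是否已经解决。如果还没有解决的话:torch.mul就是逐元素相乘,与*运算符效果相同。
当x与y形状相同时:
x.mul(y) == x*y
即对应元素相乘。
形状不同时,会按广播规则把较小的张量沿大小为1的维度扩展。TensorRT的IElementWiseLayer(ElementWiseOperation::kPROD)遵循同样的广播规则,可以直接用它实现。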
很抱歉,由于题目中并没有具体的问题描述,我无法提供任何解决方案。请您提供具体问题的描述,我会尽力为您解答。