最近在做 NSFW 识别,从 GitHub 上找到一个已经训练好的 PB 模型文件,原项目是用 Python 调用的。当我尝试用 C++ 调用时,发现输入参数不正确。通过 Python 生成日志并用 TensorBoard 查看后,发现模型的输入参数竟然是字符串(字符串类型是 json)?按我的理解,输入明明应该是一张图片才对,而 Python 脚本里也是这样传的。
gitee地址: https://gitee.com/yyj8209/CVSample/tree/master/TensorFlow/inception_model
github源地址:https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model
#include <iostream>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/image_ops.h"
//#include "eigen3/unsupported/Eigen/CXX11/Tensor"
//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <vector>
//using namespace tensorflow;
using namespace std;
using namespace tensorflow;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
// Forward declaration; the definition follows main().
// Reads the image at `file_name`, decodes it, resizes it to
// input_height x input_width and stores the resulting tensor(s) in *out_tensors.
Status ReadTensorFromImageFile(string file_name,
                               const int input_height,
                               const int input_width,
                               vector<Tensor>* out_tensors);
int main(int argc, char *argv[])
{
SessionOptions sessionOptions;
Session *session = NewSession(sessionOptions);
string modelPath = "/opt/work/build_work/TensorFlow/inception_model/output_graph.pb";
//tensorflow 官方模型
// modelPath = "/opt/work/c_work/qt_work/tensorflow_cc_demo/model/classify_image_graph_def.pb";
GraphDef graphDef;
Status statud_load = ReadBinaryProto(Env::Default(), modelPath, &graphDef);
if(statud_load.ok()) {
cout << "load pb file success : " << modelPath << endl;
}
cout << "node size:" << graphDef.node_size() << endl;
graphDef.node();
if( session->Create(graphDef).ok() ) {
cout << "success graph in session " << endl;
}
string image_path("/root/test.png");
// image_path = "/opt/work/c_work/qt_work/tensorflow_cc_demo/model/cropped_panda.jpg";
vector<Tensor> inputs;
if(ReadTensorFromImageFile(image_path, 100, 100, &inputs).ok()) {
cout << "image load success!" << endl;
cout << inputs.size() << endl;
}
vector<Tensor> outputs;
string input = "DecodeJpeg/contents:0";
string output = "final_result:0";
cout << inputs[0].DebugString() << endl;
pair<string, Tensor> img(input,inputs[0]);
vector<pair<string, tensorflow::Tensor>> runInputs = {
{"DecodeJpeg/contents:0", inputs[0]},
};
Status status = session->Run(runInputs, {output}, {}, &outputs);
cout << status << endl;
if (!status.ok()) {
cout << "run failed!" << endl;
}
cout << outputs.size() << endl;
cout << "hello wordl" << endl;
return 0;
}
// Builds and runs a small TensorFlow graph that:
//   1. reads the bytes of `file_name`,
//   2. decodes them with DecodePng if the name contains ".png", otherwise with
//      DecodeJpeg, requesting a single channel,
//   3. casts to float and prepends a batch dimension,
//   4. bilinearly resizes to input_height x input_width,
//   5. transposes with perm {0, 2, 1, 3}, i.e. swaps axes 1 and 2.
// The value of the "transpose" node is written to *out_tensors.
// Returns the first error from graph construction, session creation or Run;
// previously these statuses were ignored and OK was always returned.
Status ReadTensorFromImageFile(string file_name, const int input_height,
                               const int input_width,
                               vector<Tensor>* out_tensors) {
  auto root = Scope::NewRootScope();
  using namespace ops;

  auto file_reader = ops::ReadFile(root.WithOpName("file_reader"), file_name);
  const int wanted_channels = 1;
  Output image_reader;
  // Pick the decoder from the file name.
  if (file_name.find(".png") != std::string::npos) {
    image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
                             DecodePng::Channels(wanted_channels));
  } else {
    image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
                              DecodeJpeg::Channels(wanted_channels));
  }

  // Preprocessing: cast -> add batch dim -> resize -> transpose.
  auto float_caster = Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
  auto dims_expander = ExpandDims(root, float_caster, 0);
  auto resized = ResizeBilinear(
      root, dims_expander,
      Const(root.WithOpName("resize"), {input_height, input_width}));
  // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std});
  Transpose(root.WithOpName("transpose"), resized, {0, 2, 1, 3});

  GraphDef graph;
  // ToGraphDef also surfaces any error recorded while building the ops above.
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

  unique_ptr<Session> session(NewSession(SessionOptions()));
  if (session == nullptr) {
    return tensorflow::errors::Internal("failed to create session for '",
                                        file_name, "'");
  }
  TF_RETURN_IF_ERROR(session->Create(graph));
  // Run the graph and store the image data in the output tensor(s).
  TF_RETURN_IF_ERROR(session->Run({}, {"transpose"}, {}, out_tensors));
  return Status::OK();
}
运行后报错:Invalid argument: Expects arg[0] to be string but float is provided
希望用 C++ 调用过 PB 模型的朋友给出一些建议。我接触 TensorFlow 不多,不清楚现在的思路对不对。
问题已解决。传入的参数确实是一张图片,但要以字符串的形式传入:把图片文件的原始(未解码)字节读入一个 DT_STRING 类型的标量 Tensor,其元素类型为 tstring。读取文件时,用 `std::unique_ptr<tensorflow::RandomAccessFile>` 打开文件,并读出全部内容生成该字符串。
tensorflow已经提供了针对C++预测图片提供了示例,示例地址:https://gitee.com/mirrors/tensorflow/blob/master/tensorflow/examples/label_image/main.cc
示例中ReadEntireFile函数就是专门处理把图片转为字符串数据的函数。
最后附上运行成功的完整代码:
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
using namespace std;
using namespace tensorflow;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
// Forward declaration; the definition follows main().
// Reads the whole file `filename` through `env` into the scalar string tensor *output.
static Status ReadEntireFile(tensorflow::Env* env,
                             const string& filename,
                             Tensor* output);
int main(void){
SessionOptions sessionOptions;
Session *session = NewSession(sessionOptions);
//pb文件路径
string modelPath = "/opt/work/build_work/TensorFlow/inception_model/output_graph.pb";
GraphDef graphDef;
Status statud_load = ReadBinaryProto(Env::Default(), modelPath, &graphDef);
if(statud_load.ok()) {
cout << "load pb file success : " << modelPath << endl;
}
if( session->Create(graphDef).ok() ) {
cout << "success graph in session " << endl;
}
vector<Tensor> outputs;
string input = "DecodeJpeg/contents:0";
string output = "final_result:0";
Tensor input0(DT_STRING, TensorShape());
//图片文件
if(ReadEntireFile(tensorflow::Env::Default(), "/root/test.png", &input0).ok()) {
cout << "图片读取成功!" << endl;
}
vector<pair<string, tensorflow::Tensor>> runInputs = {
{"DecodeJpeg/contents:0", input0},
};
//预测
Status status = session->Run(runInputs, {output}, {}, &outputs);
cout << status << endl;
if (!status.ok()) {
cout << "run failed!" << endl;
}
//处理输出结果,模型输出结果就是一维数组,按照索引0,1,2,3,4分别对应porn。neutral、hentai、drawings、sexy
Tensor scores;
scores = outputs[0];
tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
//scores_flat.size() 数量是 5,打印每一个分类分数
for(int i = 0; i < scores_flat.size(); i++) {
cout << scores_flat(i) << endl;
}
return 0;
}
// Loads every byte of `filename` via `env` and stores them, unmodified, as the
// value of the scalar tstring tensor *output (the caller must pass a scalar
// DT_STRING tensor).
// Returns a DataLoss error if fewer bytes were read than the reported file
// size; filesystem errors from GetFileSize/NewRandomAccessFile/Read are
// propagated as-is.
static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
                             Tensor* output) {
  // Size the buffer once, from the size the filesystem reports.
  tensorflow::uint64 expected_bytes = 0;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &expected_bytes));

  std::unique_ptr<tensorflow::RandomAccessFile> handle;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &handle));

  // Scratch storage that Read() may fill; `bytes_read` views the result.
  string buffer(expected_bytes, '\0');
  tensorflow::StringPiece bytes_read;
  TF_RETURN_IF_ERROR(handle->Read(0, expected_bytes, &bytes_read, &buffer[0]));

  if (bytes_read.size() == expected_bytes) {
    output->scalar<tstring>()() = tstring(bytes_read);
    return Status::OK();
  }
  return tensorflow::errors::DataLoss("Truncated read of '", filename,
                                      "' expected ", expected_bytes, " got ",
                                      bytes_read.size());
}
```c++
```