// model_diagnostics.cpp
/*************************************************
USAGE:
./model_diagnostics -m=<model file location>
**************************************************/
  5. #include <opencv2/dnn.hpp>
  6. #include <opencv2/core/utils/filesystem.hpp>
  7. #include <opencv2/dnn/utils/debug_utils.hpp>
  8. #include <iostream>
  9. using namespace cv;
  10. using namespace dnn;
  11. static
  12. int diagnosticsErrorCallback(int /*status*/, const char* /*func_name*/,
  13. const char* /*err_msg*/, const char* /*file_name*/,
  14. int /*line*/, void* /*userdata*/)
  15. {
  16. fflush(stdout);
  17. fflush(stderr);
  18. return 0;
  19. }
  20. static std::string checkFileExists(const std::string& fileName)
  21. {
  22. if (fileName.empty() || utils::fs::exists(fileName))
  23. return fileName;
  24. CV_Error(Error::StsObjectNotFound, "File " + fileName + " was not found! "
  25. "Please, specify a full path to the file.");
  26. }
  27. std::string diagnosticKeys =
  28. "{ model m | | Path to the model file. }"
  29. "{ config c | | Path to the model configuration file. }"
  30. "{ framework f | | [Optional] Name of the model framework. }";
  31. int main( int argc, const char** argv )
  32. {
  33. CommandLineParser argParser(argc, argv, diagnosticKeys);
  34. argParser.about("Use this tool to run the diagnostics of provided ONNX/TF model"
  35. "to obtain the information about its support (supported layers).");
  36. if (argc == 1)
  37. {
  38. argParser.printMessage();
  39. return 0;
  40. }
  41. std::string model = checkFileExists(argParser.get<std::string>("model"));
  42. std::string config = checkFileExists(argParser.get<std::string>("config"));
  43. std::string frameworkId = argParser.get<std::string>("framework");
  44. CV_Assert(!model.empty());
  45. enableModelDiagnostics(true);
  46. skipModelImport(true);
  47. redirectError(diagnosticsErrorCallback, NULL);
  48. Net ocvNet = readNet(model, config, frameworkId);
  49. return 0;
  50. }