@@ -10,13 +10,9 @@ namespace HFA {
1010 static bool initDirectML (Ort::SessionOptions &options, int deviceIndex, std::string *errorMessage = nullptr );
1111 static bool initCUDA (Ort::SessionOptions &options, int deviceIndex, std::string *errorMessage = nullptr );
1212
13- HfaModel::HfaModel (const std::filesystem::path &encoder_Path, const std::filesystem::path &predictor_Path,
14- const ExecutionProvider provider, int device_id)
13+ HfaModel::HfaModel (const std::filesystem::path &model_Path, const ExecutionProvider provider, const int device_id)
1514 : m_env(Ort::Env(ORT_LOGGING_LEVEL_WARNING, " HfaModel" )), m_session_options(Ort::SessionOptions()),
16- m_encoder_session (nullptr ), m_predictor_session(nullptr ) {
17-
18- m_input_name = " waveform" ;
19-
15+ m_model_session (nullptr ) {
2016 m_session_options.SetInterOpNumThreads (4 );
2117
2218 // Choose execution provider based on the provided option
@@ -52,20 +48,17 @@ namespace HFA {
5248 // Create ONNX Runtime Session
5349 try {
5450#ifdef _WIN32
55- m_encoder_session = new Ort::Session (m_env, encoder_Path.wstring ().c_str (), m_session_options);
56- m_predictor_session = new Ort::Session (m_env, predictor_Path.wstring ().c_str (), m_session_options);
51+ m_model_session = new Ort::Session (m_env, model_Path.wstring ().c_str (), m_session_options);
5752#else
58- m_encoder_session = new Ort::Session (m_env, encoder_Path.c_str (), m_session_options);
59- m_predictor_session = new Ort::Session (m_env, predictor_Path.c_str (), m_session_options); // Fixed extra dot
53+ m_model_session = new Ort::Session (m_env, model_Path.c_str (), m_session_options);
6054#endif
6155 } catch (const Ort::Exception &e) {
6256 std::cout << " Failed to create session: " << e.what () << std::endl;
6357 }
6458 }
6559
6660 HfaModel::~HfaModel () {
67- delete m_encoder_session;
68- delete m_predictor_session;
61+ delete m_model_session;
6962 m_input_name = {};
7063 }
7164
@@ -76,7 +69,6 @@ namespace HFA {
7669 return false ;
7770 }
7871
79- // 1. 准备输入张量
8072 const size_t batch_size = input_data.size ();
8173 size_t max_len = 0 ;
8274 for (const auto &vec : input_data) {
@@ -87,7 +79,6 @@ namespace HFA {
8779 flattened_input.reserve (batch_size * max_len);
8880 for (const auto &vec : input_data) {
8981 flattened_input.insert (flattened_input.end (), vec.begin (), vec.end ());
90- // 填充不足部分
9182 flattened_input.insert (flattened_input.end (), max_len - vec.size (), 0 .0f );
9283 }
9384
@@ -96,34 +87,13 @@ namespace HFA {
9687 const Ort::Value input_tensor = Ort::Value::CreateTensor<float >(
9788 m_memoryInfo, flattened_input.data (), flattened_input.size (), input_shape.data (), input_shape.size ());
9889
99- // 2. 运行编码器
100- std::vector<Ort::Value> encoder_outputs;
101- try {
102- encoder_outputs = m_encoder_session->Run (Ort::RunOptions{nullptr }, &m_input_name, &input_tensor, 1 ,
103- &m_encoder_output_name, 1 );
104- } catch (const Ort::Exception &e) {
105- msg = " 编码器推理错误: " + std::string (e.what ());
106- return false ;
107- }
108-
109- // 检查编码器输出是否有效
110- if (encoder_outputs.empty ()) {
111- msg = " 编码器未返回输出" ;
112- return false ;
113- }
114-
115- // 3. 运行预测器
11690 try {
117- // 准备输出节点名称
11891 const std::vector<const char *> output_names = {m_predictor_output_name[0 ], m_predictor_output_name[1 ],
11992 m_predictor_output_name[2 ]};
12093
121- // 直接使用编码器输出作为预测器输入
122- auto predictor_outputs =
123- m_predictor_session->Run (Ort::RunOptions{nullptr }, &m_encoder_output_name, &encoder_outputs[0 ], 1 ,
124- output_names.data (), output_names.size ());
94+ auto predictor_outputs = m_model_session->Run (Ort::RunOptions{nullptr }, &m_input_name, &input_tensor, 1 ,
95+ output_names.data (), output_names.size ());
12596
126- // 4. 解析输出结果
12797 // ph_frame_logits [batch, time, classes]
12898 auto parse_3d_output = [](Ort::Value &tensor) {
12999 const auto shape = tensor.GetTensorTypeAndShapeInfo ().GetShape ();
0 commit comments