Skip to content

Commit 2b99b88

Browse files
committed
update hfa version3 - power
1 parent 583a335 commit 2b99b88

File tree

6 files changed

+19
-61
lines changed

6 files changed

+19
-61
lines changed

src/apps/HubertFA/gui/MainWindow.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,6 @@ namespace HFA {
5555
return false;
5656
}
5757

58-
YAML::Node config = YAML::LoadFile(config_file.string());
59-
const fs::path encoder_path = model_path / (config["hubert_config"]["encoder"].as<std::string>() + "-" +
60-
config["hubert_config"]["channel"].as<std::string>() + ".onnx");
61-
if (!fs::exists(encoder_path)) {
62-
error = encoder_path.string() + " does not exist";
63-
return false;
64-
}
65-
6658
YAML::Node vocab = YAML::LoadFile(vocab_file.string());
6759
const YAML::Node &dictionaries = vocab["dictionaries"];
6860
if (dictionaries) {
@@ -234,7 +226,7 @@ namespace HFA {
234226

235227
connect(remove, &QPushButton::clicked, this, &MainWindow::slot_removeListItem);
236228
connect(clear, &QPushButton::clicked, this, &MainWindow::slot_clearTaskList);
237-
connect(runHfa, &QPushButton::clicked, this, &MainWindow::slot_runFbl);
229+
connect(runHfa, &QPushButton::clicked, this, &MainWindow::slot_runHfa);
238230

239231
connect(btnOutTg, &QPushButton::clicked, this, &MainWindow::slot_outTgPath);
240232

@@ -349,7 +341,7 @@ namespace HFA {
349341
taskList->clear();
350342
}
351343

352-
void MainWindow::slot_runFbl() {
344+
void MainWindow::slot_runHfa() {
353345
out->clear();
354346
m_threadpool->clear();
355347

@@ -422,7 +414,7 @@ namespace HFA {
422414
}
423415

424416
void MainWindow::slot_threadFinished() {
425-
const auto msg = QString("Fbl complete! Total: %3, Success: %1, Failed: %2")
417+
const auto msg = QString("Hfa complete! Total: %3, Success: %1, Failed: %2")
426418
.arg(m_workTotal - m_workError)
427419
.arg(m_workError)
428420
.arg(m_workTotal);

src/apps/HubertFA/gui/MainWindow.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ namespace HFA {
7878

7979
void slot_removeListItem() const;
8080
void slot_clearTaskList() const;
81-
void slot_runFbl();
81+
void slot_runHfa();
8282

8383
void slot_oneFailed(const QString &filename, const QString &msg);
8484
void slot_oneFinished(const QString &filename, const QString &msg);

src/apps/HubertFA/util/AlignmentDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace HFA {
4444
const std::vector<std::string> &class_names,
4545
const std::map<std::string, float> &mel_spec_config)
4646
: vocab_(vocab), non_speech_phs_(class_names) {
47-
melspec_config_.hop_length = mel_spec_config.at("hop_length");
47+
melspec_config_.hop_length = mel_spec_config.at("hop_size");
4848
melspec_config_.sample_rate = mel_spec_config.at("sample_rate");
4949
frame_length_ = static_cast<float>(melspec_config_.hop_length) / melspec_config_.sample_rate;
5050
}

src/apps/HubertFA/util/Hfa.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,8 @@ namespace HFA {
2727
const auto melspec_config = config["melspec_config"].as<std::map<std::string, float>>();
2828
hfa_input_sample_rate = static_cast<int>(melspec_config.find("sample_rate")->second);
2929

30-
const fs::path encoder_path = model_folder / (config["hubert_config"]["encoder"].as<std::string>() + "-" +
31-
config["hubert_config"]["channel"].as<std::string>() + ".onnx");
32-
const fs::path predictor_path = model_folder / "model.onnx";
33-
34-
m_hfa = std::make_unique<HfaModel>(encoder_path, predictor_path, provider, device_id);
30+
const fs::path model_path = model_folder / "model.onnx";
31+
m_hfa = std::make_unique<HfaModel>(model_path, provider, device_id);
3532

3633
const fs::path vocab_file = model_folder / "vocab.yaml";
3734
YAML::Node vocab = YAML::LoadFile(vocab_file.string());

src/apps/HubertFA/util/HfaModel.cpp

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,9 @@ namespace HFA {
1010
static bool initDirectML(Ort::SessionOptions &options, int deviceIndex, std::string *errorMessage = nullptr);
1111
static bool initCUDA(Ort::SessionOptions &options, int deviceIndex, std::string *errorMessage = nullptr);
1212

13-
HfaModel::HfaModel(const std::filesystem::path &encoder_Path, const std::filesystem::path &predictor_Path,
14-
const ExecutionProvider provider, int device_id)
13+
HfaModel::HfaModel(const std::filesystem::path &model_Path, const ExecutionProvider provider, const int device_id)
1514
: m_env(Ort::Env(ORT_LOGGING_LEVEL_WARNING, "HfaModel")), m_session_options(Ort::SessionOptions()),
16-
m_encoder_session(nullptr), m_predictor_session(nullptr) {
17-
18-
m_input_name = "waveform";
19-
15+
m_model_session(nullptr) {
2016
m_session_options.SetInterOpNumThreads(4);
2117

2218
// Choose execution provider based on the provided option
@@ -52,20 +48,17 @@ namespace HFA {
5248
// Create ONNX Runtime Session
5349
try {
5450
#ifdef _WIN32
55-
m_encoder_session = new Ort::Session(m_env, encoder_Path.wstring().c_str(), m_session_options);
56-
m_predictor_session = new Ort::Session(m_env, predictor_Path.wstring().c_str(), m_session_options);
51+
m_model_session = new Ort::Session(m_env, model_Path.wstring().c_str(), m_session_options);
5752
#else
58-
m_encoder_session = new Ort::Session(m_env, encoder_Path.c_str(), m_session_options);
59-
m_predictor_session = new Ort::Session(m_env, predictor_Path.c_str(), m_session_options); // Fixed extra dot
53+
m_model_session = new Ort::Session(m_env, model_Path.c_str(), m_session_options);
6054
#endif
6155
} catch (const Ort::Exception &e) {
6256
std::cout << "Failed to create session: " << e.what() << std::endl;
6357
}
6458
}
6559

6660
HfaModel::~HfaModel() {
67-
delete m_encoder_session;
68-
delete m_predictor_session;
61+
delete m_model_session;
6962
m_input_name = {};
7063
}
7164

@@ -76,7 +69,6 @@ namespace HFA {
7669
return false;
7770
}
7871

79-
// 1. 准备输入张量
8072
const size_t batch_size = input_data.size();
8173
size_t max_len = 0;
8274
for (const auto &vec : input_data) {
@@ -87,7 +79,6 @@ namespace HFA {
8779
flattened_input.reserve(batch_size * max_len);
8880
for (const auto &vec : input_data) {
8981
flattened_input.insert(flattened_input.end(), vec.begin(), vec.end());
90-
// 填充不足部分
9182
flattened_input.insert(flattened_input.end(), max_len - vec.size(), 0.0f);
9283
}
9384

@@ -96,34 +87,13 @@ namespace HFA {
9687
const Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
9788
m_memoryInfo, flattened_input.data(), flattened_input.size(), input_shape.data(), input_shape.size());
9889

99-
// 2. 运行编码器
100-
std::vector<Ort::Value> encoder_outputs;
101-
try {
102-
encoder_outputs = m_encoder_session->Run(Ort::RunOptions{nullptr}, &m_input_name, &input_tensor, 1,
103-
&m_encoder_output_name, 1);
104-
} catch (const Ort::Exception &e) {
105-
msg = "编码器推理错误: " + std::string(e.what());
106-
return false;
107-
}
108-
109-
// 检查编码器输出是否有效
110-
if (encoder_outputs.empty()) {
111-
msg = "编码器未返回输出";
112-
return false;
113-
}
114-
115-
// 3. 运行预测器
11690
try {
117-
// 准备输出节点名称
11891
const std::vector<const char *> output_names = {m_predictor_output_name[0], m_predictor_output_name[1],
11992
m_predictor_output_name[2]};
12093

121-
// 直接使用编码器输出作为预测器输入
122-
auto predictor_outputs =
123-
m_predictor_session->Run(Ort::RunOptions{nullptr}, &m_encoder_output_name, &encoder_outputs[0], 1,
124-
output_names.data(), output_names.size());
94+
auto predictor_outputs = m_model_session->Run(Ort::RunOptions{nullptr}, &m_input_name, &input_tensor, 1,
95+
output_names.data(), output_names.size());
12596

126-
// 4. 解析输出结果
12797
// ph_frame_logits [batch, time, classes]
12898
auto parse_3d_output = [](Ort::Value &tensor) {
12999
const auto shape = tensor.GetTensorTypeAndShapeInfo().GetShape();

src/apps/HubertFA/util/HfaModel.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,18 @@ namespace HFA {
1717

1818
class HfaModel {
1919
public:
20-
explicit HfaModel(const std::filesystem::path &encoder_Path, const std::filesystem::path &predictor_Path, ExecutionProvider provider, int device_id);
20+
explicit HfaModel(const std::filesystem::path &model_Path, ExecutionProvider provider, int device_id);
2121
~HfaModel();
2222
bool forward(const std::vector<std::vector<float>> &input_data, HfaLogits &result, std::string &msg) const;
2323

2424
private:
2525
Ort::Env m_env;
2626
Ort::SessionOptions m_session_options;
27-
Ort::Session *m_encoder_session;
28-
Ort::Session *m_predictor_session;
27+
Ort::Session *m_model_session;
2928
Ort::AllocatorWithDefaultOptions m_allocator;
30-
const char *m_input_name;
31-
const char *m_encoder_output_name = "input_feature";
32-
const char* m_predictor_output_name[3] = {"ph_frame_logits", "ph_edge_logits", "cvnt_logits"};
29+
30+
const char *m_input_name = "waveform";
31+
const char *m_predictor_output_name[3] = {"ph_frame_logits", "ph_edge_logits", "cvnt_logits"};
3332

3433
#ifdef _WIN_X86
3534
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

0 commit comments

Comments
 (0)