|
| 1 | +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
1 | 15 | import unittest |
| 16 | +from unittest import mock |
2 | 17 |
|
3 | 18 | import paddle |
4 | 19 |
|
5 | | -from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import ( |
6 | | - moe_topk_select, |
7 | | -) |
8 | 20 | from fastdeploy.model_executor.layers.moe.moe import get_moe_scores |
9 | 21 |
|
10 | 22 |
|
@@ -135,15 +147,17 @@ def test_group_topk_using_phi_topk(self): |
135 | 147 | e_score_correction_bias=e_score_correction_bias, |
136 | 148 | ) |
137 | 149 |
|
138 | | - topk_values, topk_idx = moe_topk_select( |
139 | | - gating_output=gating_output, |
140 | | - n_group=n_group, |
141 | | - topk_group=topk_group, |
142 | | - top_k=top_k, |
143 | | - routed_scaling_factor=routed_scaling_factor, |
144 | | - e_score_correction_bias=e_score_correction_bias, |
145 | | - renormalize=renormalize, |
146 | | - ) |
| 150 | + with mock.patch.dict("os.environ", {"FD_USE_PHI_MOE_TOPK": "1"}): |
| 151 | + new_score, topk_values, topk_idx = get_moe_scores( |
| 152 | + gating_output=gating_output, |
| 153 | + n_group=n_group, |
| 154 | + topk_group=topk_group, |
| 155 | + top_k=top_k, |
| 156 | + routed_scaling_factor=routed_scaling_factor, |
| 157 | + e_score_correction_bias=e_score_correction_bias, |
| 158 | + renormalize=renormalize, |
| 159 | + topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, |
| 160 | + ) |
147 | 161 |
|
148 | 162 | equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item() |
149 | 163 | equal_topk_ids = paddle.allclose( |
|
0 commit comments