@@ -28,6 +28,7 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param,
                                     const ParallelArgs& parallel_args,
                                     bool isPrefill) {
   const bool use_column_parallel = cp_size_ > 1;
+  param.outputHidden = cp_size_ > 1;
   param.unpadInputs = true;
   param.gatherAhead = isPrefill;
   param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads();
@@ -101,7 +102,7 @@ NpuLmHeadImpl::NpuLmHeadImpl(const ModelContext& context) : BaseLayer(context) {
                            false);

   atb_weight_tensors_.resize(1);
-  atOutTensors_.resize(1);
+  atOutTensors_.resize(2);

   auto options = context.get_tensor_options();
   dtype_ = c10::typeMetaToScalarType(options.dtype());
@@ -146,26 +147,38 @@ int64_t NpuLmHeadImpl::init_node(atb_speed::Model::Node& node,
     return -1;
   }
   node.inTensors.resize(node.operation->GetInputNum());
-  node.outTensors.resize(1);
+  node.outTensors.resize(node.operation->GetOutputNum());

   node.inTensors.at(1) = &atb_weight_tensors_[0];

   node.variantPack.inTensors.reserve(node.inTensors.size());
   node.variantPack.inTensors.resize(node.inTensors.size());
-  node.variantPack.outTensors.reserve(1);
-  node.variantPack.outTensors.resize(1);
+  node.variantPack.outTensors.reserve(node.outTensors.size());
+  node.variantPack.outTensors.resize(node.outTensors.size());

   return atb::NO_ERROR;
 }

 torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states,
                                      const torch::Tensor& seleted_idxes,
                                      int nodeId) {
+  torch::Tensor out_hidden;
+  return forward_with_hidden(hidden_states, seleted_idxes, out_hidden, nodeId);
+}
+
+torch::Tensor NpuLmHeadImpl::forward_with_hidden(
+    const torch::Tensor& hidden_states,
+    const torch::Tensor& seleted_idxes,
+    torch::Tensor& out_hidden,
+    int nodeId) {
   atb::Status st;
   build_node_variant_pack(lm_head_node_prefill_, hidden_states, seleted_idxes);
   st = execute_node(lm_head_node_prefill_, nodeId);
   LOG_IF(FATAL, st != 0) << model_name_
                          << " execute lmhead node fail, error code: " << st;
+  if (atOutTensors_.size() > 1) {
+    out_hidden = atOutTensors_[1];
+  }
   return atOutTensors_[0];
 }

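(Aside: a hedged caller-side sketch of the two entry points above. `lm_head`, `hidden_states`, and `selected_idxes` are hypothetical stand-ins; only `forward` and `forward_with_hidden` come from this patch.)

// Old path: logits only; a throwaway out_hidden is created internally.
torch::Tensor logits =
    lm_head->forward(hidden_states, selected_idxes, /*nodeId=*/0);

// New path: also receive the hidden states when the operation exposes a
// second output (i.e. param.outputHidden was set because cp_size_ > 1).
torch::Tensor out_hidden;
torch::Tensor logits2 = lm_head->forward_with_hidden(
    hidden_states, selected_idxes, out_hidden, /*nodeId=*/0);
if (out_hidden.defined()) {
  // Hidden states are available here in addition to the logits.
}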
@@ -212,12 +225,16 @@ void NpuLmHeadImpl::build_node_variant_pack(
   inTensorDescs.at(8) = placeholder_.desc;

   atb::Status st = node.operation->InferShape(inTensorDescs, outTensorDescs);
-  at::Tensor newTensor =
-      atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(0));
-
-  atOutTensors_.at(0) = newTensor;
-  node.variantPack.outTensors.at(0) =
-      atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(0));
+  LOG_IF(FATAL, st != atb::NO_ERROR)
+      << model_name_ << " infer lmhead shape fail, error code: " << st;
+
+  atOutTensors_.resize(node.variantPack.outTensors.size());
+  for (size_t i = 0; i < node.variantPack.outTensors.size(); ++i) {
+    atOutTensors_.at(i) =
+        atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(i));
+    node.variantPack.outTensors.at(i) =
+        atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(i));
+  }
 }

 } // namespace layer
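(Aside: the loop above generalizes output allocation from one hard-coded tensor to however many outputs shape inference reports. A minimal standalone sketch of the same pattern in plain libtorch, with made-up shapes, since the ATB descriptor types are not reproduced here.)

#include <torch/torch.h>
#include <vector>

int main() {
  // Stand-ins for the inferred output descriptors (shapes only).
  std::vector<std::vector<int64_t>> inferred_shapes = {
      {4, 32000},  // slot 0: logits
      {4, 4096},   // slot 1: hidden states (present when outputHidden is set)
  };

  // Size the output list from what shape inference reported, then
  // materialize one tensor per inferred descriptor (mirrors atOutTensors_).
  std::vector<torch::Tensor> outs(inferred_shapes.size());
  for (size_t i = 0; i < inferred_shapes.size(); ++i) {
    outs[i] = torch::empty(inferred_shapes[i], torch::kFloat16);
  }
  // outs[0] -> logits, outs[1] -> hidden states.
}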