@@ -211,16 +211,20 @@ class CompilationResult:
211211 keys : List [str ]
212212
def __post_init__(
    self,
    sequence_examples: List[Optional[tf.train.SequenceExample]]) -> None:
  """Derives the serialized examples and total length, then checks sizes.

  Entries of `sequence_examples` may be None; such entries are skipped when
  serializing and when computing `length`, but they still count towards the
  size check against `rewards` and `keys`.
  (Presumably None entries come from the raw_reward_only collection path -
  confirm against the caller.)

  Args:
    sequence_examples: one entry per key; each entry is either a
      SequenceExample or None.

  Raises:
    AssertionError: if `sequence_examples`, `rewards` and `keys` differ in
      size, if `keys` and the keys of `reward_stats` differ as sets, or if
      the instance has a `sequence_examples` attribute.
  """
  # Filter the placeholders once and reuse the result for both derived fields.
  present_examples = [x for x in sequence_examples if x is not None]

  # Attributes are written through object.__setattr__ rather than plain
  # assignment (presumably because the dataclass is frozen - confirm).
  object.__setattr__(self, 'serialized_sequence_examples',
                     [x.SerializeToString() for x in present_examples])

  # An example contributes the number of entries in the `feature` field of
  # the first feature list found in its `feature_lists.feature_list` map.
  total_length = sum(
      len(next(iter(x.feature_lists.feature_list.values())).feature)
      for x in present_examples)
  object.__setattr__(self, 'length', total_length)

  # TODO: is it necessary to return keys AND reward_stats (which already has
  # the keys)? The length of sequence_examples could also be left unchecked,
  # which would allow raw_reward_only to do less work.
  # The unfiltered input is measured here, so None placeholders count.
  assert len(sequence_examples) == len(self.rewards) == len(self.keys)
  assert set(self.keys) == set(self.reward_stats.keys())
  # The init-only argument must not end up stored on the instance.
  assert not hasattr(self, 'sequence_examples')
226230
@@ -229,10 +233,9 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):
229233 """The interface of a stub to CompilationRunner, for type checkers."""
230234
@abc.abstractmethod
def collect_data(
    self,
    module_spec: corpus.ModuleSpec,
    tf_policy_path: str,
    reward_stat: Optional[Dict[str, RewardStat]],
    raw_reward_only: bool = False) -> WorkerFuture[CompilationResult]:
  """Stub signature for collecting data for a module under a policy.

  `raw_reward_only` defaults to False so that calls passing only the first
  three arguments type-check against this stub.

  Args:
    module_spec: a ModuleSpec.
    tf_policy_path: path to the tensorflow policy.
    reward_stat: reward stat of this module, None if unknown.
    raw_reward_only: whether to return the raw reward value without examples.

  Returns:
    A WorkerFuture wrapping the CompilationResult.

  Raises:
    NotImplementedError: always; this is an interface-only declaration.
  """
  raise NotImplementedError()
237240
238241 @abc .abstractmethod
@@ -275,17 +278,18 @@ def enable(self):
def cancel_all_work(self):
  """Invokes kill_all_processes() on this object's cancellation manager."""
  manager = self._cancellation_manager
  manager.kill_all_processes()
277280
278- def collect_data (
279- self , module_spec : corpus .ModuleSpec , tf_policy_path : str ,
280- reward_stat : Optional [Dict [str , RewardStat ]]) -> CompilationResult :
281+ def collect_data (self ,
282+ module_spec : corpus .ModuleSpec ,
283+ tf_policy_path : str ,
284+ reward_stat : Optional [Dict [str , RewardStat ]],
285+ raw_reward_only = False ) -> CompilationResult :
281286 """Collect data for the given IR file and policy.
282287
283288 Args:
284289 module_spec: a ModuleSpec.
285290 tf_policy_path: path to the tensorflow policy.
286291 reward_stat: reward stat of this module, None if unknown.
287- cancellation_token: a CancellationToken through which workers may be
288- signaled early termination
292+ raw_reward_only: whether to return the raw reward value without examples.
289293
290294 Returns:
291295 A CompilationResult. In particular:
@@ -311,7 +315,7 @@ def collect_data(
311315 policy_result = self ._compile_fn (
312316 module_spec ,
313317 tf_policy_path ,
314- reward_only = False ,
318+ reward_only = raw_reward_only ,
315319 cancellation_manager = self ._cancellation_manager )
316320 else :
317321 policy_result = default_result
@@ -326,6 +330,11 @@ def collect_data(
326330 raise ValueError (
327331 (f'Example { k } does not exist under default policy for '
328332 f'module { module_spec .name } ' ))
333+ if raw_reward_only :
334+ sequence_example_list .append (None )
335+ rewards .append (policy_reward )
336+ keys .append (k )
337+ continue
329338 default_reward = reward_stat [k ].default_reward
330339 moving_average_reward = reward_stat [k ].moving_average_reward
331340 sequence_example = _overwrite_trajectory_reward (
0 commit comments