@@ -258,13 +258,13 @@ def run(argv=None, test_pipeline=None):
258258 _ = (
259259 token_lists
260260 | 'MLTransformInput' >> beam .Map (lambda tokens : {'tokens' : tokens })
261- | 'ApplyMLTransform' >> MLTransform (
262- write_artifact_location = artifact_location ).with_transform (
263- ComputeAndApplyVocabulary (
264- columns = ['tokens' ],
265- top_k = known_args .vocab_size ,
266- frequency_threshold = known_args .min_frequency ,
267- vocab_filename = 'vocab' ))
261+ | 'ApplyMLTransform' >>
262+ MLTransform ( write_artifact_location = artifact_location ).with_transform (
263+ ComputeAndApplyVocabulary (
264+ columns = ['tokens' ],
265+ top_k = known_args .vocab_size ,
266+ frequency_threshold = known_args .min_frequency ,
267+ vocab_filename = 'vocab' ))
268268 | 'ExtractTransformedTokens' >> beam .Map (lambda row : row .tokens )
269269 | 'FlattenTokens' >> beam .FlatMap (list )
270270 | 'DropEmptyTokens' >> beam .Filter (bool ))
@@ -278,8 +278,8 @@ def run(argv=None, test_pipeline=None):
278278 vocab_filename = 'vocab' ,
279279 column_name = 'tokens' ))
280280 output_tokens = [known_args .oov_token ]
281- output_tokens .extend (token for token in vocab_tokens
282- if token != known_args .oov_token )
281+ output_tokens .extend (
282+ token for token in vocab_tokens if token != known_args .oov_token )
283283 if len (output_tokens ) == 1 :
284284 logging .warning (
285285 'No tokens remained after filtering; writing only reserved token %r.' ,
0 commit comments