@@ -192,7 +192,7 @@ int getPort(int defaultPort) {
192192 @ AutoValue
193193 public abstract static class Read extends PTransform <PBegin , PCollection <Row >> {
194194
195- abstract String host ();
195+ abstract @ Nullable String host ();
196196
197197 abstract int port ();
198198
@@ -252,7 +252,8 @@ public PCollection<Row> expand(PBegin input) {
252252
253253 Schema beamSchema ;
254254 try (BufferAllocator allocator = new RootAllocator (Long .MAX_VALUE );
255- FlightClient client = createClient (allocator , host (), port (), useTls ())) {
255+ FlightClient client =
256+ createClient (allocator , checkNotNull (host (), "host" ), port (), useTls ())) {
256257 FlightInfo info =
257258 client .getInfo (
258259 FlightDescriptor .command (
@@ -283,7 +284,7 @@ HeaderCallOption[] callOptions() {
283284 @ Override
284285 public void populateDisplayData (DisplayData .Builder builder ) {
285286 super .populateDisplayData (builder );
286- builder .add (DisplayData .item ("host" , host ()));
287+ builder .addIfNotNull (DisplayData .item ("host" , host ()));
287288 builder .add (DisplayData .item ("port" , port ()));
288289 builder .add (DisplayData .item ("useTls" , useTls ()));
289290 builder .addIfNotNull (DisplayData .item ("command" , command ()));
@@ -315,15 +316,18 @@ public List<? extends BoundedSource<Row>> split(
315316
316317 List <BoundedSource <Row >> sources = new ArrayList <>();
317318 try (BufferAllocator allocator = new RootAllocator (Long .MAX_VALUE );
318- FlightClient client = createClient (allocator , spec .host (), spec .port (), spec .useTls ())) {
319+ FlightClient client =
320+ createClient (
321+ allocator , checkNotNull (spec .host (), "host" ), spec .port (), spec .useTls ())) {
319322 FlightInfo info =
320323 client .getInfo (
321324 FlightDescriptor .command (
322325 checkNotNull (spec .command (), "command" ).getBytes (StandardCharsets .UTF_8 )),
323326 spec .callOptions ());
324327 for (FlightEndpoint fe : info .getEndpoints ()) {
325328 SerializableEndpoint se =
326- SerializableEndpoint .fromFlightEndpoint (fe , spec .host (), spec .port ());
329+ SerializableEndpoint .fromFlightEndpoint (
330+ fe , checkNotNull (spec .host (), "host" ), spec .port ());
327331 sources .add (new FlightBoundedSource (spec , beamSchema , se ));
328332 }
329333 }
@@ -340,7 +344,9 @@ public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
340344 return -1 ;
341345 }
342346 try (BufferAllocator allocator = new RootAllocator (Long .MAX_VALUE );
343- FlightClient client = createClient (allocator , spec .host (), spec .port (), spec .useTls ())) {
347+ FlightClient client =
348+ createClient (
349+ allocator , checkNotNull (spec .host (), "host" ), spec .port (), spec .useTls ())) {
344350 FlightInfo info =
345351 client .getInfo (
346352 FlightDescriptor .command (
@@ -388,13 +394,14 @@ public boolean start() throws IOException {
388394 allocator = new RootAllocator (Long .MAX_VALUE );
389395 Read spec = source .spec ;
390396
397+ String hostName = checkNotNull (spec .host (), "host" );
391398 if (source .endpoint != null ) {
392- String host = source .endpoint .getHost (spec . host () );
399+ String host = source .endpoint .getHost (hostName );
393400 int port = source .endpoint .getPort (spec .port ());
394401 client = createClient (allocator , host , port , spec .useTls ());
395402 stream = client .getStream (source .endpoint .getTicket (), spec .callOptions ());
396403 } else {
397- client = createClient (allocator , spec . host () , spec .port (), spec .useTls ());
404+ client = createClient (allocator , hostName , spec .port (), spec .useTls ());
398405 FlightInfo info =
399406 client .getInfo (
400407 FlightDescriptor .command (
@@ -422,7 +429,15 @@ public boolean advance() throws IOException {
422429 if (stream .next ()) {
423430 VectorSchemaRoot root = stream .getRoot ();
424431 if (root .getRowCount () > 0 ) {
425- currentBatchIterator = ArrowConversion .rowsFromRecordBatch (source .beamSchema , root );
432+ Iterator <Row > lazyIterator =
433+ ArrowConversion .rowsFromRecordBatch (source .beamSchema , root );
434+ List <Row > materializedRows = new ArrayList <>();
435+ while (lazyIterator .hasNext ()) {
436+ Row lazyRow = lazyIterator .next ();
437+ materializedRows .add (
438+ Row .withSchema (source .beamSchema ).addValues (lazyRow .getValues ()).build ());
439+ }
440+ currentBatchIterator = materializedRows .iterator ();
426441 }
427442 } else {
428443 return false ;
@@ -472,7 +487,7 @@ public BoundedSource<Row> getCurrentSource() {
472487 @ AutoValue
473488 public abstract static class Write extends PTransform <PCollection <Row >, PDone > {
474489
475- abstract String host ();
490+ abstract @ Nullable String host ();
476491
477492 abstract int port ();
478493
@@ -547,7 +562,7 @@ public PDone expand(PCollection<Row> input) {
547562 @ Override
548563 public void populateDisplayData (DisplayData .Builder builder ) {
549564 super .populateDisplayData (builder );
550- builder .add (DisplayData .item ("host" , host ()));
565+ builder .addIfNotNull (DisplayData .item ("host" , host ()));
551566 builder .add (DisplayData .item ("port" , port ()));
552567 builder .add (DisplayData .item ("useTls" , useTls ()));
553568 builder .addIfNotNull (DisplayData .item ("descriptor" , descriptor ()));
@@ -665,7 +680,6 @@ private void flush() {
665680 }
666681 ensureConnection ();
667682
668- root .setRowCount (batch .size ());
669683 for (int colIdx = 0 ; colIdx < beamSchema .getFieldCount (); colIdx ++) {
670684 FieldVector vector = root .getVector (colIdx );
671685 vector .allocateNew ();
@@ -680,6 +694,7 @@ private void flush() {
680694 }
681695 vector .setValueCount (batch .size ());
682696 }
697+ root .setRowCount (batch .size ());
683698
684699 listener .putNext ();
685700 RECORDS_WRITTEN .inc (batch .size ());
0 commit comments