Skip to content

Commit 45d45c5

Browse files
authored
fix: allow SaveModel after Complete() in VowpalWabbitThreadedLearning (#4913)
* fix: allow SaveModel/PerformanceStatistics after Complete() in VowpalWabbitThreadedLearning Previously, calling SaveModel() or accessing PerformanceStatistics after Complete() threw InvalidOperationException because Complete() immediately closed the sync action queue. This forced users into a counter-intuitive pattern of enqueuing saves before signaling completion. Changes: - Move CompleteAdding from Complete() into the root completion continuation using a new atomic CompleteAndRemoveAll(), so sync actions can be enqueued between the Complete() call and the continuation executing - Make SaveModel/PerformanceStatistics detect post-completion state and operate directly on the root VW instance via TryAdd fallback - Add Flush() method to force AllReduce sync on demand without waiting for ExampleCountPerRun threshold Fixes #4911 * fix: replace Task.CompletedTask with Task.FromResult for netstandard2.0 Task.CompletedTask was introduced in .NET 5 and is not available in netstandard2.0 which is the target framework for vw.parallel. * docs: clarify async learning model and usage patterns Improve XML doc comments on VowpalWabbitThreadedLearning to explain: - Learn() enqueues and returns immediately (async dispatch, not blocking) - Typical usage flow with code example (learn, complete, save) - What Complete() guarantees (all examples learned, final allreduce done) - That SaveModel/PerformanceStatistics work synchronously after Complete Addresses feedback from #4911 about the TPL completion model being unclear.
1 parent c35d4dd commit 45d45c5

2 files changed

Lines changed: 319 additions & 18 deletions

File tree

cs/cs_parallel/VowpalWabbitThreadedLearning.cs

Lines changed: 185 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,41 @@ namespace VW
1919
/// <summary>
2020
/// VW wrapper supporting multi-core learning by utilizing thread-based allreduce.
2121
/// </summary>
22+
/// <remarks>
23+
/// <para>
24+
/// This class manages multiple internal VW instances coordinated via allreduce.
25+
/// Each instance runs on its own thread with a bounded work queue. Calls to
26+
/// <see cref="Learn(string)"/> enqueue examples and return immediately — learning
27+
/// happens asynchronously on one of the internal instances (chosen by
28+
/// <see cref="VowpalWabbitSettings.ExampleDistribution"/>). This is not the same as
29+
/// thread-safe concurrent learning; callers submit work from a single thread (or
30+
/// coordinate externally) and the class handles parallelism internally.
31+
/// </para>
32+
/// <para>Typical usage:</para>
33+
/// <code>
34+
/// using (var vw = new VowpalWabbitThreadedLearning(settings))
35+
/// {
36+
/// foreach (var example in examples)
37+
/// vw.Learn(example);
38+
///
39+
/// // Option A: save mid-training by flushing
40+
/// var saveTask = vw.SaveModel("checkpoint.model");
41+
/// await vw.Flush();
42+
/// await saveTask;
43+
///
44+
/// // Option B: save after training is complete
45+
/// await vw.Complete();
46+
/// await vw.SaveModel("final.model");
47+
/// }
48+
/// </code>
49+
/// <para>
50+
/// Weight synchronization (allreduce) occurs automatically every
51+
/// <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> examples, or explicitly
52+
/// via <see cref="Flush"/>. Deferred operations like <see cref="SaveModel()"/> and
53+
/// <see cref="PerformanceStatistics"/> execute at these synchronization points, or
54+
/// can be called directly after <see cref="Complete"/>.
55+
/// </para>
56+
/// </remarks>
2257
public class VowpalWabbitThreadedLearning : IDisposable
2358
{
2459
/// <summary>
@@ -58,6 +93,11 @@ public class VowpalWabbitThreadedLearning : IDisposable
5893
/// </summary>
5994
private Task[] completionTasks;
6095

96+
/// <summary>
97+
/// Combined task tracking whether all completion tasks have finished.
98+
/// </summary>
99+
private Task allCompletedTask;
100+
61101
/// <summary>
62102
/// Number of examples seen sofar. Used by round robin example distributor.
63103
/// </summary>
@@ -147,8 +187,10 @@ public VowpalWabbitThreadedLearning(VowpalWabbitSettings settings)
147187
// perform final AllReduce
148188
vw.EndOfPass();
149189

150-
// execute synchronization actions
151-
foreach (var syncAction in this.syncActions.RemoveAll())
190+
// atomically drain and mark complete — allows sync actions
191+
// (e.g. SaveModel) to be enqueued between Complete() and this
192+
// continuation executing
193+
foreach (var syncAction in this.syncActions.CompleteAndRemoveAll())
152194
{
153195
syncAction(vw);
154196
}
@@ -229,6 +271,39 @@ private uint CheckEndOfPass()
229271
return exampleCount;
230272
}
231273

274+
/// <summary>
275+
/// Forces an AllReduce synchronization and drains all pending sync actions
276+
/// (e.g. <see cref="SaveModel()"/>, <see cref="PerformanceStatistics"/>)
277+
/// without waiting for <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> to be reached.
278+
/// </summary>
279+
/// <returns>Task that completes once the synchronization and all pending sync actions have executed.</returns>
280+
public Task Flush()
281+
{
282+
var completionSource = new TaskCompletionSource<bool>();
283+
284+
this.syncActions.Add(vw => completionSource.SetResult(true));
285+
286+
this.observers[0].OnNext(vw =>
287+
{
288+
// perform AllReduce
289+
vw.EndOfPass();
290+
291+
// execute synchronization actions
292+
foreach (var syncAction in this.syncActions.RemoveAll())
293+
{
294+
syncAction(vw);
295+
}
296+
});
297+
298+
for (int i = 1; i < this.observers.Length; i++)
299+
{
300+
// perform AllReduce
301+
this.observers[i].OnNext(vw => vw.EndOfPass());
302+
}
303+
304+
return completionSource.Task;
305+
}
306+
232307
/// <summary>
233308
/// Enqueues an action to be executed on one of vw instances.
234309
/// </summary>
@@ -279,6 +354,12 @@ internal Task<T> Post<T>(Func<VowpalWabbit, T> func)
279354
/// Learns from the given example.
280355
/// </summary>
281356
/// <param name="line">The example to learn.</param>
357+
/// <remarks>
358+
/// This method enqueues the example for asynchronous learning on one of the
359+
/// internal VW instances and returns immediately. The example string is captured
360+
/// by the work item and must not be mutated after this call. To ensure all
361+
/// enqueued learning is complete, call <see cref="Complete"/> or <see cref="Flush"/>.
362+
/// </remarks>
282363
public void Learn(string line)
283364
{
284365
Debug.Assert(line != null);
@@ -287,9 +368,15 @@ public void Learn(string line)
287368
}
288369

289370
/// <summary>
290-
/// Learns from the given example.
371+
/// Learns from the given multi-line example.
291372
/// </summary>
292373
/// <param name="lines">The multi-line example to learn.</param>
374+
/// <remarks>
375+
/// This method enqueues the example for asynchronous learning on one of the
376+
/// internal VW instances and returns immediately. The lines are captured
377+
/// by the work item and must not be mutated after this call. To ensure all
378+
/// enqueued learning is complete, call <see cref="Complete"/> or <see cref="Flush"/>.
379+
/// </remarks>
293380
public void Learn(IEnumerable<string> lines)
294381
{
295382
Debug.Assert(lines != null);
@@ -300,69 +387,116 @@ public void Learn(IEnumerable<string> lines)
300387
/// <summary>
301388
/// Synchronized performance statistics.
302389
/// </summary>
303-
/// <remarks>The task is only completed after synchronization of all instances, triggered <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> example.</remarks>
390+
/// <remarks>
391+
/// Can be accessed before or after <see cref="Complete"/>. If accessed after completion,
392+
/// returns statistics directly from the root VW instance.
393+
/// </remarks>
304394
public Task<VowpalWabbitPerformanceStatistics> PerformanceStatistics
305395
{
306396
get
307397
{
398+
if (this.allCompletedTask != null && this.allCompletedTask.IsCompleted)
399+
{
400+
return Task.FromResult(this.vws[0].PerformanceStatistics);
401+
}
402+
308403
var completionSource = new TaskCompletionSource<VowpalWabbitPerformanceStatistics>();
309404

310-
this.syncActions.Add(vw => completionSource.SetResult(vw.PerformanceStatistics));
405+
if (!this.syncActions.TryAdd(vw => completionSource.SetResult(vw.PerformanceStatistics)))
406+
{
407+
return Task.FromResult(this.vws[0].PerformanceStatistics);
408+
}
311409

312410
return completionSource.Task;
313411
}
314412
}
315413

316414
/// <summary>
317-
/// Signal that no more examples are send.
415+
/// Signals that no more examples will be submitted.
318416
/// </summary>
319-
/// <returns>Task completes once the learning and cleanup is done.</returns>
417+
/// <returns>Task that completes once all enqueued examples have been learned
418+
/// and a final allreduce synchronization has been performed.</returns>
419+
/// <remarks>
420+
/// After awaiting this task, the model is fully trained and methods like
421+
/// <see cref="SaveModel()"/> and <see cref="PerformanceStatistics"/> can be
422+
/// called synchronously (they execute immediately on the root VW instance
423+
/// rather than being deferred to a sync point).
424+
/// </remarks>
320425
public Task Complete()
321426
{
322-
// make sure no more sync actions are added, which might otherwise never been called
323-
this.syncActions.CompleteAdding();
324-
325427
foreach (var actionBlock in this.actionBlocks)
326428
{
327429
actionBlock.Complete();
328430
}
329431

330-
return Task.WhenAll(this.completionTasks);
331-
432+
this.allCompletedTask = Task.WhenAll(this.completionTasks);
433+
return this.allCompletedTask;
332434
}
333435

334436
/// <summary>
335437
/// Saves a model as part of the synchronization.
336438
/// </summary>
337-
/// <returns>Task compeletes once the model is saved.</returns>
439+
/// <remarks>
440+
/// Can be called before or after <see cref="Complete"/>. If called after completion,
441+
/// the model is saved directly on the root VW instance. If called before, the save is
442+
/// deferred until the next synchronization point or completion.
443+
/// </remarks>
444+
/// <returns>Task that completes once the model is saved.</returns>
338445
public Task SaveModel()
339446
{
447+
if (this.allCompletedTask != null && this.allCompletedTask.IsCompleted)
448+
{
449+
this.vws[0].SaveModel();
450+
return Task.FromResult(true);
451+
}
452+
340453
var completionSource = new TaskCompletionSource<bool>();
341454

342-
this.syncActions.Add(vw =>
455+
if (!this.syncActions.TryAdd(vw =>
343456
{
344457
vw.SaveModel();
345458
completionSource.SetResult(true);
346-
});
459+
}))
460+
{
461+
// sync actions were already drained and marked complete
462+
this.vws[0].SaveModel();
463+
return Task.FromResult(true);
464+
}
347465

348466
return completionSource.Task;
349467
}
350468

351469
/// <summary>
352470
/// Saves a model as part of the synchronization.
353471
/// </summary>
354-
/// <returns>Task compeletes once the model is saved.</returns>
472+
/// <remarks>
473+
/// Can be called before or after <see cref="Complete"/>. If called after completion,
474+
/// the model is saved directly on the root VW instance. If called before, the save is
475+
/// deferred until the next synchronization point or completion.
476+
/// </remarks>
477+
/// <returns>Task that completes once the model is saved.</returns>
355478
public Task SaveModel(string filename)
356479
{
357480
Debug.Assert(!string.IsNullOrEmpty(filename));
358481

482+
if (this.allCompletedTask != null && this.allCompletedTask.IsCompleted)
483+
{
484+
this.vws[0].SaveModel(filename);
485+
return Task.FromResult(true);
486+
}
487+
359488
var completionSource = new TaskCompletionSource<bool>();
360489

361-
this.syncActions.Add(vw =>
490+
if (!this.syncActions.TryAdd(vw =>
362491
{
363492
vw.SaveModel(filename);
364493
completionSource.SetResult(true);
365-
});
494+
}))
495+
{
496+
// sync actions were already drained and marked complete
497+
this.vws[0].SaveModel(filename);
498+
return Task.FromResult(true);
499+
}
366500

367501
return completionSource.Task;
368502
}
@@ -444,6 +578,23 @@ public void Add(T item)
444578
}
445579
}
446580

581+
/// <summary>
582+
/// Tries to add an object to the end of the list.
583+
/// </summary>
584+
/// <param name="item">The object to be added to the list.</param>
585+
/// <returns>True if the item was added; false if the list has been marked complete.</returns>
586+
public bool TryAdd(T item)
587+
{
588+
lock (this.lockObject)
589+
{
590+
if (completed)
591+
return false;
592+
593+
this.items.Add(item);
594+
return true;
595+
}
596+
}
597+
447598
/// <summary>
448599
/// Marks this list as complete. Any subsequent calls to <see cref="Add"/> will trigger an <see cref="InvalidOperationException"/>.
449600
/// </summary>
@@ -455,6 +606,22 @@ public void CompleteAdding()
455606
}
456607
}
457608

609+
/// <summary>
610+
/// Atomically marks this list as complete and removes all elements.
611+
/// </summary>
612+
/// <returns>The elements removed.</returns>
613+
public T[] CompleteAndRemoveAll()
614+
{
615+
lock (this.lockObject)
616+
{
617+
this.completed = true;
618+
var ret = this.items.ToArray();
619+
this.items.Clear();
620+
621+
return ret;
622+
}
623+
}
624+
458625
/// <summary>
459626
/// Removes all elements from the list.
460627
/// </summary>

0 commit comments

Comments
 (0)