|
2 | 2 | * @vitest-environment node |
3 | 3 | */ |
4 | 4 | import { beforeEach, describe, expect, it, vi } from 'vitest' |
5 | | -import type { DAG } from '@/executor/dag/builder' |
| 5 | +import type { DAG, DAGNode } from '@/executor/dag/builder' |
6 | 6 | import type { BlockStateWriter, ContextExtensions } from '@/executor/execution/types' |
7 | 7 | import { ParallelOrchestrator } from '@/executor/orchestrators/parallel' |
8 | 8 | import type { ExecutionContext } from '@/executor/types' |
9 | | -import { buildBranchNodeId } from '@/executor/utils/subflow-utils' |
| 9 | +import { |
| 10 | + buildBranchNodeId, |
| 11 | + buildParallelSentinelEndId, |
| 12 | + buildParallelSentinelStartId, |
| 13 | + buildSentinelEndId, |
| 14 | + buildSentinelStartId, |
| 15 | +} from '@/executor/utils/subflow-utils' |
10 | 16 |
|
11 | 17 | const { mockCompactSubflowResults } = vi.hoisted(() => ({ |
12 | 18 | mockCompactSubflowResults: vi.fn(async (results: unknown) => results), |
@@ -34,6 +40,24 @@ function createDag(): DAG { |
34 | 40 | } |
35 | 41 | } |
36 | 42 |
|
| 43 | +function createDagNode(id: string, metadata: DAGNode['metadata'] = {}): DAGNode { |
| 44 | + return { |
| 45 | + id, |
| 46 | + block: { |
| 47 | + id, |
| 48 | + position: { x: 0, y: 0 }, |
| 49 | + config: { tool: '', params: {} }, |
| 50 | + inputs: {}, |
| 51 | + outputs: {}, |
| 52 | + metadata: { id: 'function', name: id }, |
| 53 | + enabled: true, |
| 54 | + }, |
| 55 | + incomingEdges: new Set(), |
| 56 | + outgoingEdges: new Map(), |
| 57 | + metadata, |
| 58 | + } |
| 59 | +} |
| 60 | + |
37 | 61 | function createState(): BlockStateWriter { |
38 | 62 | return { |
39 | 63 | setBlockOutput: vi.fn(), |
@@ -422,6 +446,88 @@ describe('ParallelOrchestrator', () => { |
422 | 446 | expect(dirtySet.has(secondBranchId)).toBe(true) |
423 | 447 | }) |
424 | 448 |
|
| 449 | + it('marks cloned nested loop body nodes dirty for non-zero branches', () => { |
| 450 | + const dag = createDag() |
| 451 | + const parallelId = 'parallel-1' |
| 452 | + const loopId = 'loop-1' |
| 453 | + const taskId = 'task-1' |
| 454 | + const parallelStartId = buildParallelSentinelStartId(parallelId) |
| 455 | + const parallelEndId = buildParallelSentinelEndId(parallelId) |
| 456 | + const loopStartId = buildSentinelStartId(loopId) |
| 457 | + const loopEndId = buildSentinelEndId(loopId) |
| 458 | + |
| 459 | + dag.parallelConfigs.set(parallelId, { |
| 460 | + id: parallelId, |
| 461 | + nodes: [loopId], |
| 462 | + count: 2, |
| 463 | + parallelType: 'count', |
| 464 | + }) |
| 465 | + dag.loopConfigs.set(loopId, { |
| 466 | + id: loopId, |
| 467 | + nodes: [taskId], |
| 468 | + loopType: 'for', |
| 469 | + iterations: 1, |
| 470 | + }) |
| 471 | + dag.nodes.set(parallelStartId, createDagNode(parallelStartId)) |
| 472 | + dag.nodes.set(parallelEndId, createDagNode(parallelEndId)) |
| 473 | + dag.nodes.set( |
| 474 | + loopStartId, |
| 475 | + createDagNode(loopStartId, { |
| 476 | + isSentinel: true, |
| 477 | + sentinelType: 'start', |
| 478 | + subflowId: loopId, |
| 479 | + subflowType: 'loop', |
| 480 | + }) |
| 481 | + ) |
| 482 | + dag.nodes.set( |
| 483 | + taskId, |
| 484 | + createDagNode(taskId, { |
| 485 | + isLoopNode: true, |
| 486 | + subflowId: loopId, |
| 487 | + subflowType: 'loop', |
| 488 | + originalBlockId: taskId, |
| 489 | + }) |
| 490 | + ) |
| 491 | + dag.nodes.set( |
| 492 | + loopEndId, |
| 493 | + createDagNode(loopEndId, { |
| 494 | + isSentinel: true, |
| 495 | + sentinelType: 'end', |
| 496 | + subflowId: loopId, |
| 497 | + subflowType: 'loop', |
| 498 | + }) |
| 499 | + ) |
| 500 | + dag.nodes.get(loopStartId)!.outgoingEdges.set(`${loopStartId}->${taskId}`, { target: taskId }) |
| 501 | + dag.nodes.get(taskId)!.incomingEdges.add(loopStartId) |
| 502 | + dag.nodes.get(taskId)!.outgoingEdges.set(`${taskId}->${loopEndId}`, { target: loopEndId }) |
| 503 | + dag.nodes.get(loopEndId)!.incomingEdges.add(taskId) |
| 504 | + |
| 505 | + const dirtySet = new Set([parallelId]) |
| 506 | + const orchestrator = new ParallelOrchestrator(dag, createState(), null, {}) |
| 507 | + orchestrator.prepareCurrentBatch( |
| 508 | + createContext({ |
| 509 | + runFromBlockContext: { startBlockId: parallelId, dirtySet }, |
| 510 | + parallelExecutions: new Map([ |
| 511 | + [ |
| 512 | + parallelId, |
| 513 | + { |
| 514 | + parallelId, |
| 515 | + totalBranches: 2, |
| 516 | + batchSize: 2, |
| 517 | + currentBatchStart: 0, |
| 518 | + currentBatchSize: 2, |
| 519 | + branchOutputs: new Map(), |
| 520 | + }, |
| 521 | + ], |
| 522 | + ]), |
| 523 | + }), |
| 524 | + parallelId |
| 525 | + ) |
| 526 | + |
| 527 | + expect([...dirtySet]).toContain(taskId) |
| 528 | + expect([...dirtySet].some((nodeId) => nodeId.startsWith(`${taskId}__clone`))).toBe(true) |
| 529 | + }) |
| 530 | + |
425 | 531 | it('compacts accumulated outputs before scheduling later batches', async () => { |
426 | 532 | const dag = createDag() |
427 | 533 | const templateBranchId = buildBranchNodeId('task-1', 0) |
|
0 commit comments