# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
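"""Helpers for building WebDataset shards.

Download image/caption pairs listed in a parquet file into numbered tar
shards, and reshard curated ImageBatch pipeline output back into
WebDataset format.
"""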

from __future__ import annotations

import asyncio
import io
import json
import math
import os
import tarfile
from functools import partial
from multiprocessing import Pool
from typing import TYPE_CHECKING

import aiohttp
import pandas as pd
from loguru import logger
from PIL import Image
from tqdm import tqdm

if TYPE_CHECKING:
    from nemo_curator.tasks import ImageObject
    from nemo_curator.tasks.image import ImageBatch

# HTTP status codes
HTTP_OK = 200


async def fetch_image_bytes(session: aiohttp.ClientSession, url: str, retries: int = 3) -> bytes | None:
    """Download a single image, retrying up to `retries` times.

    Returns the raw bytes on success, or None once every attempt has failed.
    Failures are only logged from the second attempt onward to keep logs quiet.
    """
    for attempt in range(1, retries + 1):
        try:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=15)) as response:
                if response.status == HTTP_OK:
                    return await response.read()
                elif attempt > 1:
                    logger.debug(f"[Attempt {attempt}] Failed to download {url}: HTTP status {response.status}")
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            if attempt > 1:
                logger.debug(f"[Attempt {attempt}] Failed to download {url}: {e}")
        if attempt < retries:
            await asyncio.sleep(1)
    logger.debug(f"All {retries} attempts failed for {url}")
    return None
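

# A minimal usage sketch (hypothetical URL): fetching one image outside the
# batched download path below. Assumes no event loop is already running.
#
#   async def _demo() -> None:
#       async with aiohttp.ClientSession() as session:
#           data = await fetch_image_bytes(session, "https://example.com/cat.jpg")
#           print(len(data) if data else "download failed")
#
#   asyncio.run(_demo())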


async def process_batch(batch: pd.DataFrame, output_dir: str, batch_num: int) -> None:
    """Download one batch of URLs concurrently, then write a tar shard and a metadata parquet."""
    tar_filename = os.path.join(output_dir, f"{batch_num:05d}.tar")
    metadatas = []

    # Set timeout and connection limits for the session
    timeout = aiohttp.ClientTimeout(total=15)
    connector = aiohttp.TCPConnector(limit=256, limit_per_host=16)

    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        tasks = []
        for i, (_, row) in enumerate(batch.iterrows()):
            caption = row["TEXT"]
            url = row["URL"]
            key = f"{batch_num:05d}{i:04d}"
            meta = {"url": url, "caption": caption, "key": key}
            metadatas.append(meta)
            tasks.append(fetch_image_bytes(session, url, retries=3))
        results = await asyncio.gather(*tasks, return_exceptions=True)

    with tarfile.open(tar_filename, "w") as tar:
        for i, result in enumerate(results):
            # Only proceed for successful downloads (bytes)
            if isinstance(result, bytes) and result:
                key = f"{batch_num:05d}{i:04d}"
                # Add image bytes
                jpg_info = tarfile.TarInfo(name=f"{key}.jpg")
                jpg_info.size = len(result)
                tar.addfile(jpg_info, fileobj=io.BytesIO(result))
                # Add caption text
                caption_bytes = str(metadatas[i]["caption"]).encode("utf-8")
                txt_info = tarfile.TarInfo(name=f"{key}.txt")
                txt_info.size = len(caption_bytes)
                tar.addfile(txt_info, fileobj=io.BytesIO(caption_bytes))
                # Add JSON metadata
                json_bytes = json.dumps(metadatas[i]).encode("utf-8")
                json_info = tarfile.TarInfo(name=f"{key}.json")
                json_info.size = len(json_bytes)
                tar.addfile(json_info, fileobj=io.BytesIO(json_bytes))

    # Write the per-shard metadata parquet (rows for failed downloads are kept too)
    meta_df = pd.DataFrame(metadatas)
    parquet_path = os.path.join(output_dir, f"{batch_num:05d}.parquet")
    meta_df.to_parquet(parquet_path)


def process_parquet_chunk(chunk: tuple[int, pd.DataFrame], output_dir: str) -> None:
    """Worker-process entry point: run one batch's async download loop to completion."""
    batch_num, batch = chunk
    asyncio.run(process_batch(batch, output_dir, batch_num))


def download_webdataset(
    parquet_path: str,
    output_dir: str,
    entries_per_tar: int = 10000,
    num_processes: int = 2,
) -> None:
    """Download a parquet of (URL, TEXT) rows into numbered WebDataset tar shards."""
    os.makedirs(output_dir, exist_ok=True)

    # Read the parquet file
    df = pd.read_parquet(parquet_path)
    print(f"Loaded {len(df)} entries from parquet file")

    # Split the dataframe into chunks for multiprocessing
    chunks = [
        (batch_num, df[i : i + entries_per_tar]) for batch_num, i in enumerate(range(0, len(df), entries_per_tar))
    ]
    print(f"Split into {len(chunks)} chunks of up to {entries_per_tar} entries each")

    # Use multiprocessing to process chunks in parallel with progress tracking
    with Pool(processes=num_processes) as pool:
        func = partial(process_parquet_chunk, output_dir=output_dir)
        # Use tqdm to track progress of chunk processing
        list(
            tqdm(
                pool.imap(func, chunks),
                total=len(chunks),
                desc="Processing chunks",
                unit="chunk",
            )
        )

    # Best-effort cleanup of legacy tmp dir from previous versions
    tmp_dir = os.path.join(output_dir, "tmp")
    try:
        if os.path.isdir(tmp_dir) and not os.listdir(tmp_dir):
            os.rmdir(tmp_dir)
    except OSError as e:
        logger.debug(f"Failed to remove tmp dir {tmp_dir}: {e}")


def _prepare_metadata_record(
    image_obj: ImageObject,
    new_id: str,
    old_id_col: str | None,
) -> dict:
    """Prepare metadata record for an image object."""
    metadata_record = {
        "id": new_id,
        "original_id": image_obj.image_id,
        "original_path": image_obj.image_path,
    }

    # Preserve original ID in specified column if requested
    if old_id_col:
        metadata_record[old_id_col] = image_obj.image_id

    # Add scores and embeddings to metadata
    if image_obj.aesthetic_score is not None:
        metadata_record["aesthetic_score"] = image_obj.aesthetic_score
    if image_obj.nsfw_score is not None:
        metadata_record["nsfw_score"] = image_obj.nsfw_score
    if image_obj.embedding is not None:
        # Convert embedding to list for JSON serialization
        metadata_record["embedding"] = image_obj.embedding.tolist()
        metadata_record["embedding_dim"] = len(image_obj.embedding)

    # Add original metadata
    if image_obj.metadata:
        metadata_record.update(image_obj.metadata)

    return metadata_record


def _add_caption_to_metadata(image_obj: ImageObject, metadata_record: dict) -> None:
    """Add caption/text to metadata record."""
    if "caption" in image_obj.metadata:
        metadata_record["caption"] = str(image_obj.metadata["caption"])
    elif "text" in image_obj.metadata:
        metadata_record["caption"] = str(image_obj.metadata["text"])
    elif "TEXT" in image_obj.metadata:
        metadata_record["caption"] = str(image_obj.metadata["TEXT"])


def _add_image_to_tar(tar: tarfile.TarFile, image_obj: ImageObject, new_id: str) -> None:
    """Add image data to tar file if available."""
    if image_obj.image_data is not None:
        # Convert numpy array to PIL Image and save as bytes
        image_pil = Image.fromarray(image_obj.image_data)
        image_bytes = _image_to_bytes(image_pil)

        # Add image to tar
        image_info = tarfile.TarInfo(name=f"{new_id}.jpg")
        image_info.size = len(image_bytes.getvalue())
        tar.addfile(image_info, fileobj=image_bytes)


def _add_json_to_tar(tar: tarfile.TarFile, metadata_record: dict, new_id: str) -> None:
    """Add JSON metadata to tar file."""
    json_data = json.dumps(metadata_record, indent=2)
    json_bytes = json_data.encode("utf-8")
    json_info = tarfile.TarInfo(name=f"{new_id}.json")
    json_info.size = len(json_bytes)
    tar.addfile(json_info, fileobj=io.BytesIO(json_bytes))


def save_imagebatch_to_webdataset(
    image_batches: list[ImageBatch],
    output_path: str,
    samples_per_shard: int = 10000,
    max_shards: int = 5,
    old_id_col: str | None = None,
) -> None:
    """Save ImageBatch objects to WebDataset format with resharding.

    Args:
        image_batches: List of ImageBatch objects from pipeline output.
        output_path: Directory path where the WebDataset should be saved.
        samples_per_shard: Number of samples to include in each tar file.
        max_shards: Number of digits used to zero-pad shard filenames and IDs.
        old_id_col: If specified, preserve the original image_id in this metadata column.
    """
    os.makedirs(output_path, exist_ok=True)

    # Flatten all ImageObjects from all batches
    all_image_objects = []
    for batch in image_batches:
        all_image_objects.extend(batch.data)

    if not all_image_objects:
        print("No images to save")
        return

    print(f"Processing {len(all_image_objects)} images in shards of up to {samples_per_shard} samples")

    # Number of digits needed to encode a sample index within a shard
    max_samples_per_shard = math.ceil(math.log10(samples_per_shard))

    # Process images in shards
    shard_id = 0
    for i in range(0, len(all_image_objects), samples_per_shard):
        shard_images = all_image_objects[i : i + samples_per_shard]

        # Create output file paths
        parquet_filename = _name_partition(shard_id, max_shards=max_shards)
        tar_filename = _name_partition(shard_id, max_shards=max_shards, ext="tar")
        parquet_path = os.path.join(output_path, parquet_filename)
        tar_path = os.path.join(output_path, tar_filename)

        # Prepare metadata for parquet
        metadata_records = []

        # Create tar file with images and metadata
        with tarfile.open(tar_path, "w") as tar:
            for sample_idx, image_obj in enumerate(shard_images):
                # Generate new ID combining shard and sample indices
                new_id = _combine_id(
                    shard_id,
                    sample_idx,
                    max_shards=max_shards,
                    max_samples_per_shard=max_samples_per_shard,
                )

                # Prepare metadata record for parquet
                metadata_record = _prepare_metadata_record(image_obj, new_id, old_id_col)
                metadata_records.append(metadata_record)

                # Save image data if available
                _add_image_to_tar(tar, image_obj, new_id)

                # Store caption/text in metadata (no separate .txt file)
                _add_caption_to_metadata(image_obj, metadata_record)

                # Add JSON metadata to tar
                _add_json_to_tar(tar, metadata_record, new_id)

        # Save metadata to parquet
        metadata_df = pd.DataFrame(metadata_records)
        metadata_df.to_parquet(parquet_path, index=False)

        print(f"✓ Saved shard {shard_id:0{max_shards}d} with {len(shard_images)} samples")
        print(f"  - Tar file: {tar_filename}")
        print(f"  - Parquet file: {parquet_filename}")

        shard_id += 1

    print(f"\nSuccessfully saved {len(all_image_objects)} images to {shard_id} shards")
    print(f"Output directory: {output_path}")


def _name_partition(
    partition_index: int,
    max_shards: int = 5,
    ext: str = "parquet",
) -> str:
    """Generate partition filename with proper zero-padding."""
    return f"{partition_index:0{max_shards}d}.{ext}"


def _combine_id(shard_id: int, sample_id: int, max_shards: int = 5, max_samples_per_shard: int = 4) -> str:
    """Combine shard and sample IDs into a unique identifier."""
    int_id = sample_id + (10**max_samples_per_shard) * shard_id
    n_digits = max_samples_per_shard + max_shards
    return f"{int_id:0{n_digits}d}"


def _image_to_bytes(image_pil: Image.Image, image_format: str = "JPEG") -> io.BytesIO:
    """Convert PIL Image to BytesIO object for tarfile."""
    buffer = io.BytesIO()
    image_pil.save(buffer, format=image_format)
    buffer.seek(0)
    return buffer