import io
import itertools
import json
import csv as csv_std
from sqlite3 import OperationalError

import sqlite_utils

# Helpers referenced below (_load_extensions, _register_functions, _compile_code,
# _find_variables, verify_is_dict, file_progress, TypeTracker, decode_base64_values,
# chunks, _flatten) come from the sqlite-utils package (sqlite_utils.cli and
# sqlite_utils.utils); their exact import locations vary between versions. If
# sqlite-utils is running on pysqlite3, import OperationalError from there instead.


def insert_upsert_implementation_patched(
path,
table,
file,
pk,
flatten,
nl,
csv,
tsv,
lines,
text,
convert,
imports,
delimiter,
quotechar,
sniff,
no_headers,
encoding,
batch_size,
alter,
upsert,
ignore=False,
replace=False,
truncate=False,
not_null=None,
default=None,
detect_types=None,
analyze=False,
load_extension=None,
silent=False,
bulk_sql=None,
functions=None,
):
"""Patched version of the insert/upsert implementation from the sqlite-utils package."""
import rich_click as click
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if functions:
_register_functions(db, functions)
if (delimiter or quotechar or sniff or no_headers) and not tsv:
csv = True
if (nl + csv + tsv) >= 2:
raise click.ClickException("Use just one of --nl, --csv or --tsv")
if (csv or tsv) and flatten:
raise click.ClickException("--flatten cannot be used with --csv or --tsv")
if encoding and not (csv or tsv):
raise click.ClickException("--encoding must be used with --csv or --tsv")
if pk and len(pk) == 1:
pk = pk[0]
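    # Default to utf-8-sig so that a UTF-8 byte-order mark, if present, is stripped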
encoding = encoding or "utf-8-sig"
# The --sniff option needs us to buffer the file to peek ahead
sniff_buffer = None
if sniff:
sniff_buffer = io.BufferedReader(file, buffer_size=4096)
decoded = io.TextIOWrapper(sniff_buffer, encoding=encoding)
else:
decoded = io.TextIOWrapper(file, encoding=encoding)
try:
tracker = None
with file_progress(decoded, silent=silent) as decoded:
if csv or tsv:
if sniff:
                    # Read the first 2048 bytes and use them to detect the CSV dialect
first_bytes = sniff_buffer.peek(2048)
dialect = csv_std.Sniffer().sniff(
first_bytes.decode(encoding, "ignore")
)
else:
dialect = "excel-tab" if tsv else "excel"
csv_reader_args = {"dialect": dialect}
if delimiter:
csv_reader_args["delimiter"] = delimiter
if quotechar:
csv_reader_args["quotechar"] = quotechar
reader = csv_std.reader(decoded, **csv_reader_args)
first_row = next(reader)
if no_headers:
headers = [
"untitled_{}".format(i + 1) for i in range(len(first_row))
]
reader = itertools.chain([first_row], reader)
else:
headers = first_row
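                # Lazily turn each remaining row into a dict keyed by the headers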
docs = (dict(zip(headers, row)) for row in reader)
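                # With detect_types, wrap the stream so value types are recorded
                # and applied to the table at the end (see tracker.types below)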
if detect_types:
tracker = TypeTracker()
docs = tracker.wrap(docs)
elif lines:
docs = ({"line": line.strip()} for line in decoded)
elif text:
docs = ({"text": decoded.read()},)
else:
try:
if nl:
docs = (json.loads(line) for line in decoded if line.strip())
else:
docs = json.load(decoded)
if isinstance(docs, dict):
docs = [docs]
except json.decoder.JSONDecodeError:
raise click.ClickException(
"Invalid JSON - use --csv for CSV or --tsv for TSV files"
)
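            # --flatten collapses nested JSON objects into topkey_subkey columns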
if flatten:
docs = (_flatten(doc) for doc in docs)
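            # --convert compiles the user-supplied Python snippet once, then applies
            # it per row, per line, or to the whole text depending on the input mode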
if convert:
variable = "row"
if lines:
variable = "line"
elif text:
variable = "text"
fn = _compile_code(convert, imports, variable=variable)
if lines:
docs = (fn(doc["line"]) for doc in docs)
elif text:
# Special case: this is allowed to be an iterable
text_value = list(docs)[0]["text"]
fn_return = fn(text_value)
if isinstance(fn_return, dict):
docs = [fn_return]
else:
try:
docs = iter(fn_return)
except TypeError:
raise click.ClickException(
"--convert must return dict or iterator"
)
else:
docs = (fn(doc) or doc for doc in docs)
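            # Collect the options passed straight through to insert_all()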
extra_kwargs = {
"ignore": ignore,
"replace": replace,
"truncate": truncate,
"analyze": analyze,
}
if not_null:
extra_kwargs["not_null"] = set(not_null)
if default:
extra_kwargs["defaults"] = dict(default)
if upsert:
extra_kwargs["upsert"] = upsert
# docs should all be dictionaries
docs = (verify_is_dict(doc) for doc in docs)
# Apply {"$base64": true, ...} decoding, if needed
docs = (decode_base64_values(doc) for doc in docs)
# For bulk_sql= we use cursor.executemany() instead
if bulk_sql:
if batch_size:
doc_chunks = chunks(docs, batch_size)
else:
doc_chunks = [docs]
for doc_chunk in doc_chunks:
with db.conn:
db.conn.cursor().executemany(bulk_sql, doc_chunk)
return
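            # Default path: stream the docs into the table in batches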
try:
db[table].insert_all(
docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
)
except Exception as e:
if (
isinstance(e, OperationalError)
and e.args
and "has no column named" in e.args[0]
):
raise click.ClickException(
"{}\n\nTry using --alter to add additional columns".format(
e.args[0]
)
)
# If we can find sql= and parameters= arguments, show those
variables = _find_variables(e.__traceback__, ["sql", "parameters"])
if "sql" in variables and "parameters" in variables:
raise click.ClickException(
"{}\n\nsql = {}\nparameters = {}".format(
str(e), variables["sql"], variables["parameters"]
)
)
else:
raise e
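        # Apply the column types recorded by TypeTracker when type detection was used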
if tracker is not None:
db[table].transform(types=tracker.types)
finally:
decoded.close()
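
# A minimal usage sketch, not part of the original: the file name, database path,
# table name, and flag values below are all hypothetical. The function expects a
# binary file object, mirroring how the sqlite-utils CLI passes in the opened file.
if __name__ == "__main__":
    with open("rows.csv", "rb") as fp:  # hypothetical CSV input
        insert_upsert_implementation_patched(
            path="data.db",  # hypothetical SQLite database path
            table="rows",
            file=fp,
            pk=["id"],  # unwrapped to the string "id" inside the function
            flatten=False,
            nl=False,
            csv=True,  # parse the input as CSV
            tsv=False,
            lines=False,
            text=False,
            convert=None,
            imports=None,
            delimiter=None,
            quotechar=None,
            sniff=False,
            no_headers=False,
            encoding=None,  # falls back to "utf-8-sig"
            batch_size=100,
            alter=False,
            upsert=False,  # plain insert rather than upsert
        )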