Source code for dnarecords.writer

"""
DNARecords available writers.
"""


# pylint: disable=too-many-instance-attributes
# pylint: disable=too-few-public-methods
# It is reasonable in this case.
[docs]class DNARecordsWriter: """Genomics data (vcf, bgen, etc.) to tfrecords or parquet. **Sample and variant wise**. Core class to go from genomics data to tfrecords or parquet files ready to use with Deep Learning frameworks like Tensorflow or Pytorch. * Able to generate **dna records variant wise or sample wise (i.e. transposing the matrix)**. * **Takes advantage of sparsity** (very convenient to save space and computation, specially with Deep Learning). * **Scales automatically to any sized dataset. Tested on UKBB**. Example -------- .. code-block:: python import dnarecords as dr hl = dr.helper.DNARecordsUtils.init_hail() hl.utils.get_1kg('/tmp/1kg') mt = hl.read_matrix_table('/tmp/1kg/1kg.mt') mt = mt.annotate_entries(dosage=hl.pl_dosage(mt.PL)) output = '/tmp/dnarecords/output' writer = dr.writer.DNARecordsWriter(mt.dosage) writer.write(output, sparse=True, sample_wise=True, variant_wise=True, tfrecord_format=True, parquet_format=True, write_mode='overwrite', gzip=True) reader = dr.reader.DNASparkReader(output) reader.sample_wise_dnarecords().show(2) reader.variant_wise_dnarecords().show(2) .. code-block:: text +---+--------------------+--------------------+----------------+ |key| chr1_indices| chr1_values|chr1_dense_shape| ... +---+--------------------+--------------------+----------------+ | 26|[0, 2, 4, 5, 6, 7...|[0.33607214002352...| 909| ... | 29|[0, 1, 2, 3, 4, 5...|[0.20076008098505...| 909| ... +---+--------------------+--------------------+----------------+ only showing top 1 row +--------------------+--------------------+----+-----------+ | indices| values| key|dense_shape| +--------------------+--------------------+----+-----------+ |[0, 1, 2, 3, 4, 5...|[0.9984177, 0.007...|3506| 10880| |[0, 1, 2, 3, 4, 5...|[0.11181577, 0.01...|3764| 10880| +--------------------+--------------------+----+-----------+ only showing top 2 rows ... :param expr: a Hail expression. Currently, ony expressions coercible to numeric are supported :param block_size: entries per block in the internal operations :param staging: path to staging directory to use for intermediate data. Default: /tmp/dnarecords/staging. """ from typing import TYPE_CHECKING if TYPE_CHECKING: from hail import MatrixTable, Expression from pyspark.sql import DataFrame _i_blocks: set _j_blocks: set _nrows: int _ncols: int _sparsity: float _chrom_ranges: dict _mt: 'MatrixTable' _skeys: 'DataFrame' _vkeys: 'DataFrame' def __init__(self, expr: 'Expression', block_size: int = int(1e6), staging: str = '/tmp/dnarecords/staging'): self._assert_expr_type(expr) self._expr = expr self._block_size = block_size self._kv_blocks_path = f'{staging}/kv-blocks' self._vw_dna_staging = f'{staging}/vw-dnaparquet' self._sw_dna_staging = f'{staging}/sw-dnaparquet' @staticmethod def _assert_expr_type(expr): import hail.expr.expressions.expression_typecheck as tc msg = 'Only expr_numeric are supported at this time.' msg += '\nConsider redefining the entry_expr as an expr_numeric.' msg += '\nFeel free to use the helper functions on this module.' if not tc.expr_numeric.can_coerce(expr.dtype): raise Exception(msg) def _set_mt(self): from hail.expr.expressions.expression_utils import matrix_table_source self._mt = matrix_table_source('dnarecords', self._expr).annotate_entries(v=self._expr) def _index_mt(self): self._mt = self._mt.add_row_index('i') self._mt = self._mt.key_rows_by('i') self._mt = self._mt.add_col_index('j') self._mt = self._mt.key_cols_by('j') def _set_vkeys_skeys(self): self._vkeys = self._mt.key_rows_by().rows().to_spark().withColumnRenamed('i', 'key').cache() self._skeys = self._mt.key_cols_by().cols().to_spark().withColumnRenamed('j', 'key').cache() def _set_chrom_ranges(self): from pyspark.sql import functions as F gdf = self._vkeys.select('`locus.contig`', 'key').groupby('`locus.contig`') gdf = gdf.agg(F.min('key').alias('start'), F.max('key').alias('end')) self._chrom_ranges = {r['locus.contig'].replace('chr', ''): [r['start'], r['end']] for i, r in gdf.toPandas().iterrows()} def _update_vkeys_by_chrom_ranges(self): from dnarecords.helper import DNARecordsUtils import pyspark.sql.functions as F chr_start_keys = [[k, start] for k, [start, _] in self._chrom_ranges.items()] spark = DNARecordsUtils.spark_session() chr_start_keys_df = spark.sparkContext.parallelize(chr_start_keys).toDF(['locus.contig', 'chr_start_key']) self._vkeys = self._vkeys \ .join(chr_start_keys_df, on='locus.contig', how='left') \ .withColumn('chr_key', F.col('key') - F.col('chr_start_key')) def _select_ijv(self): self._mt = self._mt.select_globals().select_rows().select_cols().select_entries('v') def _filter_out_undefined_entries(self): from dnarecords.helper import DNARecordsUtils hl = DNARecordsUtils.init_hail() self._mt = self._mt.filter_entries(hl.is_defined(self._mt.v)) def _filter_out_zeroes(self): from dnarecords.helper import DNARecordsUtils hl = DNARecordsUtils.init_hail() self._mt = self._mt.filter_entries(0 != hl.coalesce(self._mt.v, 0)) def _set_max_nrows_ncols(self): self._nrows = self._mt.count_rows() self._ncols = self._mt.count_cols() def _set_sparsity(self): mts = self._mt.head(10000, None) entries = mts.key_cols_by().key_rows_by().entries().to_spark().filter('v is not null').count() self._sparsity = entries / (mts.count_rows() * mts.count_cols()) def _get_block_size(self): import math M, N, S = self._nrows + 1, self._ncols + 1, self._sparsity + 1e-6 B = self._block_size / S m = math.ceil(math.sqrt(B * M / N)) n = math.ceil(math.sqrt(B * N / M)) return m, n def _build_ij_blocks(self): import pyspark.sql.functions as F m, n = self._get_block_size() df = self._mt.key_cols_by().key_rows_by().entries().to_spark().filter('v is not null') df = df.withColumn('ib', F.expr(f"i div {m}")) df = df.withColumn('jb', F.expr(f"j div {n}")) df.repartition('ib', 'jb').write.partitionBy('ib', 'jb').mode('overwrite').parquet(self._kv_blocks_path) def _set_ij_blocks(self): import re from dnarecords.helper import DNARecordsUtils hl = DNARecordsUtils.init_hail() all_blocks = [p for p in hl.hadoop_ls(f'{self._kv_blocks_path}/*') if p['is_dir']] self._i_blocks = {re.search(r'ib=(\d+)', p['path']).group(1) for p in all_blocks} self._j_blocks = {re.search(r'jb=(\d+)', p['path']).group(1) for p in all_blocks} @staticmethod def _dnarecord_variant_agg(rows, mkey): for row in rows: sd = dict(sorted(row['data'].items())) yield {'key': row['i'], 'indices': list(sd.keys()), 'values': list(sd.values()), 'dense_shape': mkey + 1} @staticmethod def _dnarecord_sample_agg(rows, chrom_ranges): import numpy as np for row in rows: sd = dict(sorted(row['data'].items())) keys = np.array(list(sd.keys())) vals = np.array(list(sd.values())) result = {'key': row['j']} for locus, [start, end] in chrom_ranges.items(): mask = (start <= keys) & (keys <= end) result.update({f'chr{locus}_indices': (keys[mask] - start).tolist(), f'chr{locus}_values': vals[mask].tolist(), f'chr{locus}_dense_shape': end - start + 1}) yield result @staticmethod def _to_dnarecord_variant_wise(mkey): return lambda rows: DNARecordsWriter._dnarecord_variant_agg(rows, mkey) @staticmethod def _to_dnarecord_sample_wise(chrom_ranges): return lambda rows: DNARecordsWriter._dnarecord_sample_agg(rows, chrom_ranges) def _build_dna_block_variant_wise(self, blocks_path, output): from dnarecords.helper import DNARecordsUtils import pyspark.sql.functions as F def get_dna_schema(gdf): import pyspark.sql.types as pytypes data_type = [s for s in gdf.schema if s.name == 'data'][0].dataType values_type = data_type.valueType return pytypes.StructType([pytypes.StructField("key", pytypes.LongType(), False), pytypes.StructField("indices", pytypes.ArrayType(pytypes.LongType(), False), False), pytypes.StructField("values", pytypes.ArrayType(values_type, False), False), pytypes.StructField("dense_shape", pytypes.LongType(), False)]) spark = DNARecordsUtils.spark_session() df = spark.read.parquet(blocks_path).select('i', 'j', 'v') df = df.groupBy('i').agg(F.map_from_entries(F.collect_list(F.struct('j', 'v'))).alias('data')) schema = get_dna_schema(df) mapper = self._to_dnarecord_variant_wise(self._ncols - 1) df.rdd.mapPartitions(mapper).toDF(schema).repartition(1).write.mode('overwrite').parquet(output) def _build_dna_block_sample_wise(self, blocks_path, output, chrom_ranges): from dnarecords.helper import DNARecordsUtils import pyspark.sql.functions as F def get_dna_schema(gdf): import pyspark.sql.types as pytypes data_type = [s for s in gdf.schema if s.name == 'data'][0].dataType values_type = data_type.valueType gdf_schema = pytypes.StructType([pytypes.StructField("key", pytypes.LongType(), False)]) for k in chrom_ranges.keys(): gdf_schema.add(f"chr{k}_indices", pytypes.ArrayType(pytypes.LongType(), False), False) gdf_schema.add(f"chr{k}_values", pytypes.ArrayType(values_type, False), False) gdf_schema.add(f"chr{k}_dense_shape", pytypes.LongType(), False) return gdf_schema spark = DNARecordsUtils.spark_session() df = spark.read.parquet(blocks_path).select('i', 'j', 'v') df = df.groupBy('j').agg(F.map_from_entries(F.collect_list(F.struct('i', 'v'))).alias('data')) schema = get_dna_schema(df) mapper = self._to_dnarecord_sample_wise(chrom_ranges) df.rdd.mapPartitions(mapper).toDF(schema).repartition(1).write.mode('overwrite').parquet(output) def _build_dna_blocks(self, by): from multiprocessing.pool import ThreadPool import multiprocessing as ms pool = ThreadPool(ms.cpu_count() - 1) if by == 'i': params = [ [f'{self._kv_blocks_path}/ib={ib}/jb=*', f'{self._vw_dna_staging}/{ib}-of-{len(self._i_blocks):04}'] for ib in self._i_blocks] pool.starmap(self._build_dna_block_variant_wise, params) if by == 'j': params = [ [f'{self._kv_blocks_path}/ib=*/jb={jb}', f'{self._sw_dna_staging}/{jb}-of-{len(self._j_blocks):04}', self._chrom_ranges] for jb in self._j_blocks] pool.starmap(self._build_dna_block_sample_wise, params) pool.close() pool.join() # pylint: disable=too-many-arguments # It is reasonable in this case. @staticmethod def _write_dnarecords(output, output_schema, dna_blocks, write_mode, gzip, tfrecord_format): from dnarecords.helper import DNARecordsUtils spark = DNARecordsUtils.spark_session() df = spark.read.parquet(dna_blocks) df_writer = df.write.mode(write_mode) if tfrecord_format: df_writer = df_writer.format("tfrecord").option("recordType", "Example") if gzip: # Needs huge overhead memory df_writer = df_writer.option("codec", "org.apache.hadoop.io.compress.GzipCodec") else: df_writer = df_writer.format('parquet') if gzip: df_writer = df_writer.option("compression", "gzip") df_writer.save(output) sc_writer = spark.read.json(spark.sparkContext.parallelize([df.schema.json()])).coalesce(1).write sc_writer.mode(write_mode).format('json').save(output_schema) @staticmethod def _write_key_files(source, output, tfrecord_format, write_mode): from dnarecords.helper import DNARecordsUtils import pyspark.sql.functions as F spark = DNARecordsUtils.spark_session() if tfrecord_format: reader = spark.read.format("tfrecord").option("recordType", "Example") else: reader = spark.read.format("parquet") df = reader.load(source).withColumn("path", F.regexp_extract(F.input_file_name(), f"(.*){source}/(.*)", 2)) df.select('key', 'path').write.mode(write_mode).parquet(output) # pylint: disable=too-many-arguments # It is reasonable in this case.
[docs] def write(self, output: str, sparse: bool = True, sample_wise: bool = True, variant_wise: bool = False, tfrecord_format: bool = True, parquet_format: bool = False, write_mode: str = 'error', gzip: bool = True) -> None: """DNARecords spark writer. Writes a DNARecords dataset based on the Hail `expr` provided in the class constructor. :rtype: Dict[str, DataFrame] :param output: path to the output location of the DNARecords. :param sparse: generate sparse data (filtering out any zero values). Default: True. :param sample_wise: generate DNARecords in a sample wise fashion (i.e. transposing the matrix, one column -> one record). Default: True. :param variant_wise: generate DNARecords in a variant wise fashion (i.e. one row -> one record) Default: False. :param tfrecord_format: generate tfrecords output files. Default: True. :param parquet_format: generate parquet output files. Default: False. :param write_mode: spark write mode parameter ('error', 'overwrite', etc.). Default: 'error'. :param gzip: gzip the output files. Default: True. :return: A dictionary with DataFrames for each generated output. See Also -------- :obj:`.DNARecordsUtils.dnarecords_tree`, :obj:`.DNASparkReader` """ from dnarecords.helper import DNARecordsUtils if not sample_wise and not variant_wise: raise Exception('At least one of sample_wise, variant_wise must be True') if not tfrecord_format and not parquet_format: raise Exception('At least one of tfrecord_format, parquet_format must be True') otree = DNARecordsUtils.dnarecords_tree(output) self._set_mt() self._index_mt() self._set_vkeys_skeys() self._set_chrom_ranges() self._update_vkeys_by_chrom_ranges() self._vkeys.write.mode(write_mode).parquet(otree['vkeys']) self._skeys.write.mode(write_mode).parquet(otree['skeys']) self._select_ijv() self._filter_out_undefined_entries() if sparse: self._filter_out_zeroes() self._set_max_nrows_ncols() self._set_sparsity() self._build_ij_blocks() self._set_ij_blocks() if variant_wise: self._build_dna_blocks('i') if sample_wise: self._build_dna_blocks('j') if variant_wise: if tfrecord_format: self._write_dnarecords(otree['vwrec'], otree['vwrsc'], f'{self._vw_dna_staging}/*', write_mode, gzip, True) self._write_key_files(otree['vwrec'], otree['vwrfs'], True, write_mode) if parquet_format: self._write_dnarecords(otree['vwpar'], otree['vwpsc'], f'{self._vw_dna_staging}/*', write_mode, gzip, False) self._write_key_files(otree['vwpar'], otree['vwpfs'], False, write_mode) if sample_wise: if tfrecord_format: self._write_dnarecords(otree['swrec'], otree['swrsc'], f'{self._sw_dna_staging}/*', write_mode, gzip, True) self._write_key_files(otree['swrec'], otree['swrfs'], True, write_mode) if parquet_format: self._write_dnarecords(otree['swpar'], otree['swpsc'], f'{self._sw_dna_staging}/*', write_mode, gzip, False) self._write_key_files(otree['swpar'], otree['swpfs'], False, write_mode)