pixelator_core/common/
io.rs

1use std::fs::File;
2use std::path::Path;
3use std::sync::Arc;
4
5use arrow::array::StringArray;
6use arrow::array::{Array, BooleanArray, PrimitiveArray, RecordBatch, UInt64Array};
7use arrow::compute;
8use arrow::datatypes::{DataType, Field, Schema, SchemaRef, UInt64Type};
9use itertools::Itertools;
10use log::info;
11use parquet::arrow::ArrowWriter;
12use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
13
14use crate::common::graph::Graph;
15use crate::common::node_indexing::UmiToNodeIndexMapping;
16use crate::common::node_partitioning::NodePartitioning;
17use crate::common::types::{EdgeWeight, UMI, UMIPair};
18
/// An iterator that yields UMI pairs from a parquet file one by one.
/// These can be mapped to node indices and used to construct a graph.
pub struct ParquetUMIPairIter {
    // Streaming reader over the parquet file's record batches.
    reader: ParquetRecordBatchReader,

    // Total row count taken from the parquet file metadata; backs
    // `ExactSizeIterator::len` without scanning the file.
    expected_size: i64,

    // Buffers for the currently loaded batch
    // We hold the specific typed arrays to avoid downcasting on every row
    col_src: Option<PrimitiveArray<UInt64Type>>,
    col_dst: Option<PrimitiveArray<UInt64Type>>,

    // Pointers for iteration logic
    // `current_idx` is the next row to yield within the buffered batch;
    // `batch_len` is that batch's row count (0 before the first load).
    current_idx: usize,
    batch_len: usize,
}
35
36impl ParquetUMIPairIter {
37    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
38        let file = File::open(path)?;
39
40        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
41        let num_rows = builder.metadata().file_metadata().num_rows();
42        let reader = builder.with_batch_size(8192).build()?;
43
44        Ok(Self {
45            reader,
46            expected_size: num_rows,
47            col_src: None,
48            col_dst: None,
49            current_idx: 0,
50            batch_len: 0,
51        })
52    }
53
54    fn load_next_batch(&mut self) -> bool {
55        match self.reader.next() {
56            Some(Ok(batch)) => {
57                let src_array = batch
58                    .column_by_name("umi1")
59                    .expect("Could not find umi1 column in data")
60                    .as_any()
61                    .downcast_ref::<PrimitiveArray<UInt64Type>>()
62                    .expect("Column umi1 is not UInt64");
63
64                let dst_array = batch
65                    .column_by_name("umi2")
66                    .expect("Could not find umi2 column in data")
67                    .as_any()
68                    .downcast_ref::<PrimitiveArray<UInt64Type>>()
69                    .expect("Column umi2 is not UInt64");
70
71                self.col_src = Some(src_array.clone());
72                self.col_dst = Some(dst_array.clone());
73
74                self.batch_len = batch.num_rows();
75                self.current_idx = 0;
76                true
77            }
78            _ => false, // End of file or error
79        }
80    }
81}
82
83impl Iterator for ParquetUMIPairIter {
84    type Item = UMIPair;
85
86    fn next(&mut self) -> Option<Self::Item> {
87        // If we are at the end of the current batch, load the next one
88        if self.current_idx >= self.batch_len && !self.load_next_batch() {
89            return None; // No more batches
90        }
91
92        let src = self.col_src.as_ref()?.value(self.current_idx);
93        let dst = self.col_dst.as_ref()?.value(self.current_idx);
94
95        self.current_idx += 1;
96
97        Some((src as UMI, dst as UMI))
98    }
99}
100
impl ExactSizeIterator for ParquetUMIPairIter {
    // NOTE(review): this reports the TOTAL row count from the parquet file
    // metadata, not the number of elements remaining. The ExactSizeIterator
    // contract requires `len()` to equal the remaining count, so this is only
    // accurate before any item has been consumed — confirm no caller relies
    // on `len()` after partially draining the iterator. Fixing it would need
    // a consumed-rows counter on the struct.
    fn len(&self) -> usize {
        self.expected_size as usize
    }
}
106
107pub fn write_record_batches_to_path<P: AsRef<Path>, I>(
108    path: P,
109    schema: SchemaRef,
110    record_batches: I,
111) -> Result<(), Box<dyn std::error::Error>>
112where
113    I: Iterator<Item = RecordBatch>,
114{
115    let file = File::create(path)?;
116    let mut writer = ArrowWriter::try_new(file, schema, None)?;
117
118    for batch in record_batches {
119        writer.write(&batch)?;
120    }
121
122    writer.close()?;
123    Ok(())
124}
125
126/// Remove crossing edges from the provided edge list and write the result to
127/// a new parquet file.
128///
129/// # Arguments
130/// * `input_edgelist_path`: path to the edge list to filtler
131/// * `output_path`: path where the resulting filtered edgelist will be written
132/// * `node_partitioning`: partitioning to be used for detecting crossing edges
133/// * `mapping`: mapping between umis and node indices
134pub fn filter_out_crossing_edges_from_edge_list<PIn: AsRef<Path>, POut: AsRef<Path>, T>(
135    input_edgelist_path: &PIn,
136    output_path: &POut,
137    node_partitioning: &T,
138    mapping: &UmiToNodeIndexMapping,
139) -> Result<(), Box<dyn std::error::Error>>
140where
141    T: NodePartitioning,
142{
143    let file = File::open(input_edgelist_path)?;
144    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
145    let mut fields = builder
146        .schema()
147        .fields()
148        .iter()
149        .cloned()
150        .collect::<Vec<Arc<Field>>>();
151    let mut reader = builder.with_batch_size(8192).build()?;
152
153    let output_file = File::create(output_path)?;
154    fields.push(Arc::new(Field::new("component", DataType::Utf8, true)));
155    let new_schema = Arc::new(Schema::new(fields));
156    let mut writer = ArrowWriter::try_new(output_file, new_schema.clone(), None)?;
157
158    while let Some(Ok(batch)) = reader.next() {
159        let component1_iter = batch
160            .column_by_name("umi1")
161            .expect("Could not find umi1 column in data")
162            .as_any()
163            .downcast_ref::<PrimitiveArray<UInt64Type>>()
164            .expect("Column umi1 is not UInt64")
165            .iter()
166            .map(|umi| {
167                let umi = umi.expect("umi1 column contains null values");
168                node_partitioning
169                    .get_node_to_partition_map()
170                    .get(mapping.map_umi_to_node_index(umi as UMI))
171                    .unwrap_or_else(|| panic!("umi {} not found in umi mapping", umi))
172            });
173
174        let component2_iter = batch
175            .column_by_name("umi2")
176            .expect("Could not find umi2 column in data")
177            .as_any()
178            .downcast_ref::<PrimitiveArray<UInt64Type>>()
179            .expect("Column umi2 is not UInt64")
180            .iter()
181            .map(|umi| {
182                let umi = umi.expect("umi2 column contains null values");
183                node_partitioning
184                    .get_node_to_partition_map()
185                    .get(mapping.map_umi_to_node_index(umi as UMI))
186                    .unwrap_or_else(|| panic!("umi {} not found in umi mapping", umi))
187            });
188
189        let component_col = StringArray::from(
190            component1_iter
191                .zip(component2_iter)
192                .map(|(c1, c2)| if c1 == c2 { Some(c1.to_string()) } else { None })
193                .collect::<Vec<Option<String>>>(),
194        );
195
196        let mask: BooleanArray = component_col.iter().map(|c| c.is_some()).collect();
197        let mut columns = batch.columns().to_vec();
198        columns.push(Arc::new(component_col));
199        let new_batch = RecordBatch::try_new(new_schema.clone(), columns)?;
200
201        let filtered_batch = compute::filter_record_batch(&new_batch, &mask)?;
202
203        if filtered_batch.num_rows() > 0 {
204            writer.write(&filtered_batch)?;
205        }
206    }
207
208    writer.close()?;
209
210    Ok(())
211}
212
213pub fn write_node_partitions_to_parquet<P: AsRef<Path>, T>(
214    path: P,
215    node_partitioning: &T,
216    mapping: &UmiToNodeIndexMapping,
217    batch_size: Option<usize>,
218) -> Result<(), Box<dyn std::error::Error>>
219where
220    T: NodePartitioning,
221{
222    let schema = Arc::new(Schema::new(vec![
223        Field::new("umi", DataType::UInt64, false),
224        Field::new("partition_id", DataType::UInt64, false),
225    ]));
226
227    let mapping_node_to_partition = node_partitioning
228        .get_node_to_partition_map()
229        .iter()
230        .enumerate()
231        .map(|(node_idx, partition_idx)| (mapping.map_node_index_to_umi(node_idx), partition_idx));
232
233    let chunk_size = batch_size.unwrap_or(4096);
234    let chunks = mapping_node_to_partition.chunks(chunk_size);
235
236    let record_batches = chunks.into_iter().map(|chunk| {
237        let mut umis: Vec<u64> = Vec::with_capacity(chunk_size);
238        let mut partitions: Vec<u64> = Vec::with_capacity(chunk_size);
239
240        for (umi, partition) in chunk {
241            umis.push(umi as u64);
242            partitions.push(*partition as u64);
243        }
244
245        let umi_array = Arc::new(UInt64Array::from(umis));
246        let partition_array = Arc::new(UInt64Array::from(partitions));
247
248        RecordBatch::try_new(schema.clone(), vec![umi_array, partition_array])
249            .expect("Failed to build record batch")
250    });
251
252    write_record_batches_to_path(path, schema.clone(), record_batches)
253}
254
255pub fn create_graph_and_umi_mapping_from_parquet_file<T>(
256    parquet_file: &str,
257) -> (UmiToNodeIndexMapping, Graph<T>)
258where
259    T: EdgeWeight,
260{
261    info!("Creating UMI mapping...");
262    let umi_pair_iterator =
263        ParquetUMIPairIter::new(parquet_file).expect("Failed to create ParquetUMIPairIter");
264    let umis: Vec<UMIPair> = umi_pair_iterator.collect();
265
266    let umi_mapping = UmiToNodeIndexMapping::from_umi_pairs(&umis);
267
268    let num_nodes = umi_mapping.get_num_of_nodes();
269    let edges = umi_mapping.map_umi_pair_iterator_to_edge(umis.iter().copied());
270
271    info!("Creating graph...");
272    let graph = Graph::<T>::from_edges(edges, num_nodes);
273    info!(
274        "Graph created with {} nodes, {} edge entries, total edge weight {}",
275        graph.get_num_nodes(),
276        graph.get_edge_entry_count(),
277        graph.get_total_edge_weight()
278    );
279    (umi_mapping, graph)
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    use crate::common::node_partitioning::FastNodePartitioning;
    use crate::common::types::PartitionId;
    use itertools::izip;
    use tempfile::NamedTempFile;

    /// Filtering must drop every crossing edge: each surviving row's two UMIs
    /// must share a partition, and the `component` column must record it.
    #[test]
    fn test_filter_edge_list() {
        let test_data = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("test_data/mix_40cells_1pc_1000rows.parquet");
        let (umi_mapping, graph) = create_graph_and_umi_mapping_from_parquet_file::<u8>(
            test_data
                .to_str()
                .expect("File name is not a valid UTF-8 string"),
        );
        // Make up partitions based on the mod 4 parity of umis
        let partitioning = FastNodePartitioning::initialize_from_partitions(
            (0..graph.get_num_nodes())
                .map(|node_id| umi_mapping.map_node_index_to_umi(node_id) % 4)
                .collect::<Vec<PartitionId>>(),
        );

        let output_file = NamedTempFile::new().expect("Failed to create tmp file");

        // BUG FIX: the result was previously discarded with `let _ = ...`, so
        // a failing filter run could not fail the test.
        filter_out_crossing_edges_from_edge_list(
            &test_data,
            &output_file,
            &partitioning,
            &umi_mapping,
        )
        .expect("filter_out_crossing_edges_from_edge_list failed");

        // Open the output only after it has been completely written.
        let temp_file = std::fs::File::open(output_file.path()).unwrap();
        let reader_builder = ParquetRecordBatchReaderBuilder::try_new(temp_file).unwrap();
        let reader = reader_builder.build().unwrap();
        assert!(
            reader
                .flat_map(|batch| {
                    let batch = batch.unwrap();
                    let umi1 = batch
                        .column_by_name("umi1")
                        .unwrap()
                        .as_any()
                        .downcast_ref::<PrimitiveArray<UInt64Type>>()
                        .unwrap()
                        .clone();
                    let umi2 = batch
                        .column_by_name("umi2")
                        .unwrap()
                        .as_any()
                        .downcast_ref::<PrimitiveArray<UInt64Type>>()
                        .unwrap()
                        .clone();
                    let component = batch
                        .column_by_name("component")
                        .unwrap()
                        .as_any()
                        .downcast_ref::<StringArray>()
                        .unwrap()
                        .into_iter()
                        .map(|s| s.unwrap().to_string())
                        .collect::<Vec<String>>();

                    izip!(umi1.into_iter(), umi2.into_iter(), component.into_iter()).collect::<Vec<(
                        Option<u64>,
                        Option<u64>,
                        String,
                    )>>(
                    )
                })
                .map(|(umi1, umi2, component)| (umi1.unwrap(), umi2.unwrap(), component))
                .all(
                    |(umi1, umi2, component)| (umi1 % 4).to_string() == component
                        && (umi2 % 4).to_string() == component
                )
        );
    }
}
361}