1use std::fs::File;
2use std::path::Path;
3use std::sync::Arc;
4
5use arrow::array::StringArray;
6use arrow::array::{Array, BooleanArray, PrimitiveArray, RecordBatch, UInt64Array};
7use arrow::compute;
8use arrow::datatypes::{DataType, Field, Schema, SchemaRef, UInt64Type};
9use itertools::Itertools;
10use log::info;
11use parquet::arrow::ArrowWriter;
12use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
13
14use crate::common::graph::Graph;
15use crate::common::node_indexing::UmiToNodeIndexMapping;
16use crate::common::node_partitioning::NodePartitioning;
17use crate::common::types::{EdgeWeight, UMI, UMIPair};
18
/// Streaming iterator over `(umi1, umi2)` pairs stored in a Parquet file.
///
/// Reads the file batch-by-batch through a `ParquetRecordBatchReader` and
/// keeps only the current batch's two UMI columns in memory, yielding one
/// `UMIPair` per row via the `Iterator` impl below.
pub struct ParquetUMIPairIter {
    // Underlying Arrow reader producing record batches from the Parquet file.
    reader: ParquetRecordBatchReader,

    // Total row count reported by the Parquet file metadata; backs the
    // `ExactSizeIterator` impl.
    expected_size: i64,

    // The "umi1" / "umi2" columns of the batch currently being iterated;
    // `None` until the first batch has been loaded.
    col_src: Option<PrimitiveArray<UInt64Type>>,
    col_dst: Option<PrimitiveArray<UInt64Type>>,

    // Next row index to yield within the current batch.
    current_idx: usize,
    // Number of rows in the current batch.
    batch_len: usize,
}
35
36impl ParquetUMIPairIter {
37 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
38 let file = File::open(path)?;
39
40 let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
41 let num_rows = builder.metadata().file_metadata().num_rows();
42 let reader = builder.with_batch_size(8192).build()?;
43
44 Ok(Self {
45 reader,
46 expected_size: num_rows,
47 col_src: None,
48 col_dst: None,
49 current_idx: 0,
50 batch_len: 0,
51 })
52 }
53
54 fn load_next_batch(&mut self) -> bool {
55 match self.reader.next() {
56 Some(Ok(batch)) => {
57 let src_array = batch
58 .column_by_name("umi1")
59 .expect("Could not find umi1 column in data")
60 .as_any()
61 .downcast_ref::<PrimitiveArray<UInt64Type>>()
62 .expect("Column umi1 is not UInt64");
63
64 let dst_array = batch
65 .column_by_name("umi2")
66 .expect("Could not find umi2 column in data")
67 .as_any()
68 .downcast_ref::<PrimitiveArray<UInt64Type>>()
69 .expect("Column umi2 is not UInt64");
70
71 self.col_src = Some(src_array.clone());
72 self.col_dst = Some(dst_array.clone());
73
74 self.batch_len = batch.num_rows();
75 self.current_idx = 0;
76 true
77 }
78 _ => false, }
80 }
81}
82
83impl Iterator for ParquetUMIPairIter {
84 type Item = UMIPair;
85
86 fn next(&mut self) -> Option<Self::Item> {
87 if self.current_idx >= self.batch_len && !self.load_next_batch() {
89 return None; }
91
92 let src = self.col_src.as_ref()?.value(self.current_idx);
93 let dst = self.col_dst.as_ref()?.value(self.current_idx);
94
95 self.current_idx += 1;
96
97 Some((src as UMI, dst as UMI))
98 }
99}
100
/// NOTE(review): `ExactSizeIterator::len` is documented to report the number
/// of *remaining* items, shrinking as the iterator is consumed. This impl
/// always returns the total row count from the file metadata, so `len()` is
/// only accurate before iteration begins — confirm callers rely on it solely
/// as an up-front size hint (fixing it properly would require tracking the
/// number of rows already yielded).
impl ExactSizeIterator for ParquetUMIPairIter {
    fn len(&self) -> usize {
        self.expected_size as usize
    }
}
106
107pub fn write_record_batches_to_path<P: AsRef<Path>, I>(
108 path: P,
109 schema: SchemaRef,
110 record_batches: I,
111) -> Result<(), Box<dyn std::error::Error>>
112where
113 I: Iterator<Item = RecordBatch>,
114{
115 let file = File::create(path)?;
116 let mut writer = ArrowWriter::try_new(file, schema, None)?;
117
118 for batch in record_batches {
119 writer.write(&batch)?;
120 }
121
122 writer.close()?;
123 Ok(())
124}
125
126pub fn filter_out_crossing_edges_from_edge_list<PIn: AsRef<Path>, POut: AsRef<Path>, T>(
135 input_edgelist_path: &PIn,
136 output_path: &POut,
137 node_partitioning: &T,
138 mapping: &UmiToNodeIndexMapping,
139) -> Result<(), Box<dyn std::error::Error>>
140where
141 T: NodePartitioning,
142{
143 let file = File::open(input_edgelist_path)?;
144 let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
145 let mut fields = builder
146 .schema()
147 .fields()
148 .iter()
149 .cloned()
150 .collect::<Vec<Arc<Field>>>();
151 let mut reader = builder.with_batch_size(8192).build()?;
152
153 let output_file = File::create(output_path)?;
154 fields.push(Arc::new(Field::new("component", DataType::Utf8, true)));
155 let new_schema = Arc::new(Schema::new(fields));
156 let mut writer = ArrowWriter::try_new(output_file, new_schema.clone(), None)?;
157
158 while let Some(Ok(batch)) = reader.next() {
159 let component1_iter = batch
160 .column_by_name("umi1")
161 .expect("Could not find umi1 column in data")
162 .as_any()
163 .downcast_ref::<PrimitiveArray<UInt64Type>>()
164 .expect("Column umi1 is not UInt64")
165 .iter()
166 .map(|umi| {
167 let umi = umi.expect("umi1 column contains null values");
168 node_partitioning
169 .get_node_to_partition_map()
170 .get(mapping.map_umi_to_node_index(umi as UMI))
171 .unwrap_or_else(|| panic!("umi {} not found in umi mapping", umi))
172 });
173
174 let component2_iter = batch
175 .column_by_name("umi2")
176 .expect("Could not find umi2 column in data")
177 .as_any()
178 .downcast_ref::<PrimitiveArray<UInt64Type>>()
179 .expect("Column umi2 is not UInt64")
180 .iter()
181 .map(|umi| {
182 let umi = umi.expect("umi2 column contains null values");
183 node_partitioning
184 .get_node_to_partition_map()
185 .get(mapping.map_umi_to_node_index(umi as UMI))
186 .unwrap_or_else(|| panic!("umi {} not found in umi mapping", umi))
187 });
188
189 let component_col = StringArray::from(
190 component1_iter
191 .zip(component2_iter)
192 .map(|(c1, c2)| if c1 == c2 { Some(c1.to_string()) } else { None })
193 .collect::<Vec<Option<String>>>(),
194 );
195
196 let mask: BooleanArray = component_col.iter().map(|c| c.is_some()).collect();
197 let mut columns = batch.columns().to_vec();
198 columns.push(Arc::new(component_col));
199 let new_batch = RecordBatch::try_new(new_schema.clone(), columns)?;
200
201 let filtered_batch = compute::filter_record_batch(&new_batch, &mask)?;
202
203 if filtered_batch.num_rows() > 0 {
204 writer.write(&filtered_batch)?;
205 }
206 }
207
208 writer.close()?;
209
210 Ok(())
211}
212
213pub fn write_node_partitions_to_parquet<P: AsRef<Path>, T>(
214 path: P,
215 node_partitioning: &T,
216 mapping: &UmiToNodeIndexMapping,
217 batch_size: Option<usize>,
218) -> Result<(), Box<dyn std::error::Error>>
219where
220 T: NodePartitioning,
221{
222 let schema = Arc::new(Schema::new(vec![
223 Field::new("umi", DataType::UInt64, false),
224 Field::new("partition_id", DataType::UInt64, false),
225 ]));
226
227 let mapping_node_to_partition = node_partitioning
228 .get_node_to_partition_map()
229 .iter()
230 .enumerate()
231 .map(|(node_idx, partition_idx)| (mapping.map_node_index_to_umi(node_idx), partition_idx));
232
233 let chunk_size = batch_size.unwrap_or(4096);
234 let chunks = mapping_node_to_partition.chunks(chunk_size);
235
236 let record_batches = chunks.into_iter().map(|chunk| {
237 let mut umis: Vec<u64> = Vec::with_capacity(chunk_size);
238 let mut partitions: Vec<u64> = Vec::with_capacity(chunk_size);
239
240 for (umi, partition) in chunk {
241 umis.push(umi as u64);
242 partitions.push(*partition as u64);
243 }
244
245 let umi_array = Arc::new(UInt64Array::from(umis));
246 let partition_array = Arc::new(UInt64Array::from(partitions));
247
248 RecordBatch::try_new(schema.clone(), vec![umi_array, partition_array])
249 .expect("Failed to build record batch")
250 });
251
252 write_record_batches_to_path(path, schema.clone(), record_batches)
253}
254
255pub fn create_graph_and_umi_mapping_from_parquet_file<T>(
256 parquet_file: &str,
257) -> (UmiToNodeIndexMapping, Graph<T>)
258where
259 T: EdgeWeight,
260{
261 info!("Creating UMI mapping...");
262 let umi_pair_iterator =
263 ParquetUMIPairIter::new(parquet_file).expect("Failed to create ParquetUMIPairIter");
264 let umis: Vec<UMIPair> = umi_pair_iterator.collect();
265
266 let umi_mapping = UmiToNodeIndexMapping::from_umi_pairs(&umis);
267
268 let num_nodes = umi_mapping.get_num_of_nodes();
269 let edges = umi_mapping.map_umi_pair_iterator_to_edge(umis.iter().copied());
270
271 info!("Creating graph...");
272 let graph = Graph::<T>::from_edges(edges, num_nodes);
273 info!(
274 "Graph created with {} nodes, {} edge entries, total edge weight {}",
275 graph.get_num_nodes(),
276 graph.get_edge_entry_count(),
277 graph.get_total_edge_weight()
278 );
279 (umi_mapping, graph)
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    use crate::common::node_partitioning::FastNodePartitioning;
    use crate::common::types::PartitionId;
    use itertools::izip;
    use tempfile::NamedTempFile;

    /// End-to-end check: after filtering, every surviving edge has both
    /// endpoints in the same (umi % 4) partition and the `component` column
    /// records that partition.
    #[test]
    fn test_filter_edge_list() {
        let test_data = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("test_data/mix_40cells_1pc_1000rows.parquet");
        let (umi_mapping, graph) = create_graph_and_umi_mapping_from_parquet_file::<u8>(
            test_data
                .to_str()
                .expect("File name is not a valid UTF-8 string"),
        );
        // Partition nodes by umi % 4 so the expected component of every kept
        // edge is trivially predictable.
        let partitioning = FastNodePartitioning::initialize_from_partitions(
            (0..graph.get_num_nodes())
                .map(|node_id| umi_mapping.map_node_index_to_umi(node_id) % 4)
                .collect::<Vec<PartitionId>>(),
        );

        let output_file = NamedTempFile::new().expect("Failed to create tmp file");

        // Check the Result: previously it was discarded with `let _ =`, so a
        // filtering failure could let the assertions pass vacuously.
        filter_out_crossing_edges_from_edge_list(
            &test_data,
            &output_file,
            &partitioning,
            &umi_mapping,
        )
        .expect("filter_out_crossing_edges_from_edge_list failed");

        // Open the output only after it has been fully written and closed
        // (opening the handle before the write relied on Unix inode
        // semantics and is not portable).
        let temp_file = std::fs::File::open(output_file.path()).unwrap();
        let reader_builder = ParquetRecordBatchReaderBuilder::try_new(temp_file).unwrap();
        let reader = reader_builder.build().unwrap();
        assert!(
            reader
                .flat_map(|batch| {
                    let batch = batch.unwrap();
                    let umi1 = batch
                        .column_by_name("umi1")
                        .unwrap()
                        .as_any()
                        .downcast_ref::<PrimitiveArray<UInt64Type>>()
                        .unwrap()
                        .clone();
                    let umi2 = batch
                        .column_by_name("umi2")
                        .unwrap()
                        .as_any()
                        .downcast_ref::<PrimitiveArray<UInt64Type>>()
                        .unwrap()
                        .clone();
                    let component = batch
                        .column_by_name("component")
                        .unwrap()
                        .as_any()
                        .downcast_ref::<StringArray>()
                        .unwrap()
                        .into_iter()
                        .map(|s| s.unwrap().to_string())
                        .collect::<Vec<String>>();

                    izip!(umi1.into_iter(), umi2.into_iter(), component.into_iter()).collect::<Vec<(
                        Option<u64>,
                        Option<u64>,
                        String,
                    )>>(
                    )
                })
                .map(|(umi1, umi2, component)| (umi1.unwrap(), umi2.unwrap(), component))
                .all(
                    |(umi1, umi2, component)| (umi1 % 4).to_string() == component
                        && (umi2 % 4).to_string() == component
                )
        );
    }
}
361}