-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathlocalhost_versioned_run.rs
More file actions
131 lines (112 loc) · 4.02 KB
/
localhost_versioned_run.rs
File metadata and controls
131 lines (112 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use arrow::util::pretty::pretty_format_batches;
use async_trait::async_trait;
use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion_distributed::{
DefaultChannelResolver, DistributedExt, GetWorkerInfoRequest, SessionStateBuilderExt,
WorkerResolver, create_worker_client, display_plan_ascii,
};
use futures::TryStreamExt;
use std::error::Error;
use structopt::StructOpt;
use url::Url;
// Command-line arguments for this runner, parsed with structopt.
//
// NOTE: the `///` comments on the fields below are not only documentation —
// structopt renders them as the `--help` text, so editing them changes the
// CLI's user-facing output.
#[derive(StructOpt)]
#[structopt(
name = "versioned_run",
about = "A localhost Distributed DataFusion runner with worker version filtering"
)]
struct Args {
/// The SQL query to run.
// No `long`/`short` is given, so structopt treats this as a positional argument.
#[structopt()]
query: String,
/// The ports holding Distributed DataFusion workers.
// `use_delimiter = true` allows the ports to be given comma-separated,
// e.g. `--cluster-ports 8080,8081`.
#[structopt(long = "cluster-ports", use_delimiter = true)]
cluster_ports: Vec<u16>,
/// Only use workers reporting this version.
/// When omitted, all workers in --cluster-ports are used.
// NOTE(review): this defines a value-taking `--version` option, which
// presumably takes the place of clap's auto-generated `--version` flag —
// confirm that is intended.
#[structopt(long)]
version: Option<String>,
/// Whether the distributed plan should be rendered instead of executing the query.
#[structopt(long)]
show_distributed_plan: bool,
}
/// Checks whether the worker at `url` reports `expected_version`.
///
/// A channel is obtained from `channel_resolver`, a `GetWorkerInfo` RPC is
/// issued over it, and the reported version string is compared with
/// `expected_version`.
///
/// This is a best-effort probe: it never returns an error. It yields `false`
/// when the channel cannot be obtained, when the RPC fails, or when the worker
/// reports a different version.
async fn worker_has_version(
    channel_resolver: &DefaultChannelResolver,
    url: &Url,
    expected_version: &str,
) -> bool {
    let channel = match channel_resolver.get_channel(url).await {
        Ok(channel) => channel,
        Err(_) => return false,
    };

    let mut worker_client = create_worker_client(channel);
    match worker_client.get_worker_info(GetWorkerInfoRequest {}).await {
        Ok(reply) => reply.into_inner().version == expected_version,
        Err(_) => false,
    }
}
/// Runs a SQL query on a localhost Distributed DataFusion cluster.
///
/// If `--version` is given, each port in `--cluster-ports` is probed with
/// `worker_has_version` and only the workers that answer with that version are
/// kept. Otherwise every listed port is used. The query is then planned against
/// the `weather` table, registered from `testdata/weather` as Parquet. It is
/// either printed as a distributed plan (`--show-distributed-plan`) or executed
/// with the results printed as a table.
///
/// # Errors
///
/// Returns an error when:
/// - version filtering leaves no workers;
/// - a worker URL cannot be built;
/// - the session cannot be configured;
/// - registering, planning, executing, or formatting the query fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let args = Args::from_args();

    let ports = if let Some(target_version) = &args.version {
        let channel_resolver = DefaultChannelResolver::default();
        let mut compatible = Vec::with_capacity(args.cluster_ports.len());
        for &port in &args.cluster_ports {
            let url = Url::parse(&format!("http://localhost:{port}"))?;
            if worker_has_version(&channel_resolver, &url, target_version).await {
                compatible.push(port);
            } else {
                // `worker_has_version` also returns `false` for workers that are
                // unreachable or whose RPC fails, so don't claim a mismatch only.
                println!(
                    "Excluding worker on port {port} (unreachable or version mismatch)"
                );
            }
        }
        if compatible.is_empty() {
            return Err(format!("No workers matched version '{target_version}'").into());
        }
        println!(
            "Using {}/{} workers matching version '{target_version}'\n",
            compatible.len(),
            args.cluster_ports.len()
        );
        compatible
    } else {
        args.cluster_ports
    };

    let localhost_resolver = LocalhostWorkerResolver { ports };
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_distributed_worker_resolver(localhost_resolver)
        .with_distributed_planner()
        .with_distributed_files_per_task(1)?
        .build();
    let ctx = SessionContext::from(state);
    ctx.register_parquet("weather", "testdata/weather", ParquetReadOptions::default())
        .await?;

    let df = ctx.sql(&args.query).await?;
    if args.show_distributed_plan {
        let plan = df.create_physical_plan().await?;
        println!("{}", display_plan_ascii(plan.as_ref(), false));
    } else {
        let stream = df.execute_stream().await?;
        let batches = stream.try_collect::<Vec<_>>().await?;
        let formatted = pretty_format_batches(&batches)?;
        println!("{formatted}");
    }
    Ok(())
}
#[derive(Clone)]
struct LocalhostWorkerResolver {
ports: Vec<u16>,
}
#[async_trait]
impl WorkerResolver for LocalhostWorkerResolver {
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
Ok(self
.ports
.iter()
.map(|port| Url::parse(&format!("http://localhost:{port}")).unwrap())
.collect())
}
}