dragonfly_client_rs/
client.rs

1mod methods;
2mod models;
3
4use chrono::{DateTime, TimeDelta, Utc};
5use flate2::read::GzDecoder;
6pub use methods::*;
7pub use models::*;
8use tempfile::{tempdir, tempfile, TempDir};
9
10use color_eyre::Result;
11use reqwest::{blocking::Client, Url};
12use std::{io, time::Duration};
13use tracing::{error, info, trace, warn};
14
15pub struct AuthState {
16    pub access_token: String,
17    pub expires_at: DateTime<Utc>,
18}
19
20pub struct RulesState {
21    pub rules: yara::Rules,
22    pub hash: String,
23}
24
25#[warn(clippy::module_name_repetitions)]
26pub struct DragonflyClient {
27    pub client: Client,
28    pub authentication_state: AuthState,
29    pub rules_state: RulesState,
30}
31
32impl DragonflyClient {
33    pub fn new() -> Result<Self> {
34        let client = Client::builder().gzip(true).build()?;
35
36        let auth_response = fetch_access_token(&client)?;
37        let rules_response = fetch_rules(&client, &auth_response.access_token)?;
38
39        let authentication_state = AuthState {
40            access_token: auth_response.access_token,
41            expires_at: Utc::now() + TimeDelta::seconds(auth_response.expires_in.into()),
42        };
43
44        let rules_state = RulesState {
45            rules: rules_response.compile()?,
46            hash: rules_response.hash,
47        };
48
49        Ok(Self {
50            client,
51            authentication_state,
52            rules_state,
53        })
54    }
55
56    /// Update the state with a new access token, if it's expired.
57    ///
58    /// If the token is not expired, then nothing is done.
59    /// If an error occurs while reauthenticating, the function retries with an exponential backoff
60    /// described by the equation `min(10 * 60, 2^(x - 1))` where `x` is the number of failed tries.
61    pub fn reauthenticate(&mut self) {
62        if Utc::now() <= self.authentication_state.expires_at {
63            return;
64        }
65
66        let base = 2_f64;
67        let initial_timeout = 1_f64;
68        let mut tries = 0;
69
70        let authentication_response = loop {
71            let r = fetch_access_token(self.get_http_client());
72            match r {
73                Ok(authentication_response) => break authentication_response,
74                Err(e) => {
75                    let sleep_time = if tries < 10 {
76                        let t = initial_timeout * base.powf(f64::from(tries));
77                        warn!("Failed to reauthenticate after {tries} tries! Error: {e:#?}. Trying again in {t:.3} seconds");
78                        t
79                    } else {
80                        error!("Failed to reauthenticate after {tries} tries! Error: {e:#?}. Trying again in 600.000 seconds");
81                        600_f64
82                    };
83
84                    std::thread::sleep(Duration::from_secs_f64(sleep_time));
85                    tries += 1;
86                }
87            }
88        };
89
90        trace!("Successfully got new access token!");
91
92        self.authentication_state = AuthState {
93            access_token: authentication_response.access_token,
94            expires_at: Utc::now() + TimeDelta::seconds(authentication_response.expires_in.into()),
95        };
96
97        info!("Successfully reauthenticated.");
98    }
99
100    /// Update the global ruleset. Waits for a write lock.
101    pub fn update_rules(&mut self) -> Result<()> {
102        self.reauthenticate();
103
104        let response = fetch_rules(
105            self.get_http_client(),
106            &self.authentication_state.access_token,
107        )?;
108        self.rules_state.rules = response.compile()?;
109        self.rules_state.hash = response.hash;
110
111        Ok(())
112    }
113
114    pub fn bulk_get_job(&mut self, n_jobs: usize) -> reqwest::Result<Vec<Job>> {
115        self.reauthenticate();
116
117        fetch_bulk_job(
118            self.get_http_client(),
119            &self.authentication_state.access_token,
120            n_jobs,
121        )
122    }
123
124    pub fn get_job(&mut self) -> reqwest::Result<Option<Job>> {
125        self.reauthenticate();
126
127        // not `slice::first` because we want to own the Job
128        self.bulk_get_job(1).map(|jobs| jobs.into_iter().nth(0))
129    }
130
131    /// Send a [`crate::client::models::ScanResult`] to mainframe
132    pub fn send_result(&mut self, body: models::ScanResult) -> reqwest::Result<()> {
133        self.reauthenticate();
134
135        send_result(
136            self.get_http_client(),
137            &self.authentication_state.access_token,
138            body,
139        )
140    }
141
142    /// Return a reference to the underlying HTTP Client
143    pub fn get_http_client(&self) -> &Client {
144        &self.client
145    }
146}
147
148/// Download and unpack a tarball, return the [`TempDir`] containing the contents.
149fn extract_tarball<R: io::Read>(response: R) -> Result<TempDir> {
150    let mut tarball = tar::Archive::new(GzDecoder::new(response));
151    let tmpdir = tempdir()?;
152    tarball.unpack(tmpdir.path())?;
153    Ok(tmpdir)
154}
155
156/// Download and extract a zip, return the [`TempDir`] containing the contents.
157fn extract_zipfile<R: io::Read>(mut response: R) -> Result<TempDir> {
158    let mut file = tempfile()?;
159
160    // first write the archive to a file because `response` isn't Seek, which is needed by
161    // `zip::ZipArchive::new`
162    io::copy(&mut response, &mut file)?;
163
164    let mut zip = zip::ZipArchive::new(file)?;
165    let tmpdir = tempdir()?;
166    zip.extract(tmpdir.path())?;
167
168    Ok(tmpdir)
169}
170
171pub fn download_distribution(http_client: &Client, download_url: Url) -> Result<TempDir> {
172    // This conversion is fast as per the docs
173    let is_tarball = download_url.as_str().ends_with(".tar.gz");
174    let response = http_client.get(download_url).send()?;
175
176    if is_tarball {
177        extract_tarball(response)
178    } else {
179        extract_zipfile(response)
180    }
181}