1use std::path::PathBuf;
2use std::{collections::HashSet, path::Path};
3
4use color_eyre::Result;
5use reqwest::{blocking::Client, Url};
6use tempfile::TempDir;
7use walkdir::WalkDir;
8use yara::Rules;
9
10use crate::{
11 client::{download_distribution, Job, SubmitJobResultsSuccess},
12 exts::RuleExt,
13 utils::create_inspector_url,
14};
15
/// A rule name paired with the score it contributes.
///
/// Values are built via `RuleScore::from` from the matches returned by
/// `yara::Rules::scan_file` (see `Distribution::scan_file`); that conversion is
/// defined outside this section. `Hash`/`Eq` are derived so that identical
/// `(name, score)` pairs collapse when collected into a `HashSet`.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub struct RuleScore {
    /// Identifier of the rule.
    pub name: String,
    /// Score attributed to the rule.
    pub score: i64,
}
21
/// The outcome of scanning a single file: where the file is and which rules it matched.
#[derive(Debug)]
pub struct FileScanResult {
    /// Path of the scanned file. `Distribution::scan_file` stores it relative to
    /// the distribution's temporary directory.
    pub path: PathBuf,
    /// Rules recorded for this file, each with its score.
    pub rules: Vec<RuleScore>,
}
28
29impl FileScanResult {
30 fn new(path: PathBuf, rules: Vec<RuleScore>) -> Self {
31 Self { path, rules }
32 }
33
34 fn calculate_score(&self) -> i64 {
36 self.rules.iter().map(|i| i.score).sum()
37 }
38}
39
/// A downloaded distribution that is ready to be scanned.
struct Distribution {
    /// Temporary directory holding the distribution's files; it is the root
    /// that `scan` walks and that reported file paths are made relative to.
    dir: TempDir,
    /// Inspector URL for this distribution, copied into the scan results.
    inspector_url: Url,
}
45
46impl Distribution {
47 fn scan(&mut self, rules: &Rules) -> Result<DistributionScanResults> {
48 let mut file_scan_results: Vec<FileScanResult> = Vec::new();
49 for entry in WalkDir::new(self.dir.path())
50 .into_iter()
51 .filter_map(|dirent| dirent.into_iter().find(|de| de.file_type().is_file()))
52 {
53 let file_scan_result = self.scan_file(entry.path(), rules)?;
54 file_scan_results.push(file_scan_result);
55 }
56
57 Ok(DistributionScanResults::new(
58 file_scan_results,
59 self.inspector_url.clone(),
60 ))
61 }
62
63 fn scan_file(&self, path: &Path, rules: &Rules) -> Result<FileScanResult> {
69 let rules = rules
70 .scan_file(path, 10)?
71 .into_iter()
72 .filter(|rule| {
73 let filetypes = rule.get_filetypes();
74 filetypes.is_empty()
75 || filetypes
76 .iter()
77 .any(|filetype| path.to_string_lossy().ends_with(filetype))
78 })
79 .map(RuleScore::from)
80 .collect();
81
82 Ok(FileScanResult::new(
83 self.relative_to_archive_root(path)?,
84 rules,
85 ))
86 }
87
88 fn relative_to_archive_root(&self, path: &Path) -> Result<PathBuf> {
90 Ok(path.strip_prefix(self.dir.path())?.to_path_buf())
91 }
92}
93
/// The scan results of every file in one distribution.
#[derive(Debug)]
pub struct DistributionScanResults {
    /// One entry per scanned file.
    file_scan_results: Vec<FileScanResult>,

    /// Inspector URL of the distribution; `inspector_url()` appends the path of
    /// the highest-scoring file to it.
    inspector_url: Url,
}
103
104impl DistributionScanResults {
105 pub fn new(file_scan_results: Vec<FileScanResult>, inspector_url: Url) -> Self {
108 Self {
109 file_scan_results,
110 inspector_url,
111 }
112 }
113
114 pub fn get_most_malicious_file(&self) -> Option<&FileScanResult> {
119 self.file_scan_results
120 .iter()
121 .max_by_key(|i| i.calculate_score())
122 }
123
124 fn get_matched_rules(&self) -> HashSet<&RuleScore> {
126 let mut rules: HashSet<&RuleScore> = HashSet::new();
127 for file_scan_result in &self.file_scan_results {
128 for rule in &file_scan_result.rules {
129 rules.insert(rule);
130 }
131 }
132
133 rules
134 }
135
136 pub fn get_total_score(&self) -> i64 {
138 self.get_matched_rules().iter().map(|rule| rule.score).sum()
139 }
140
141 pub fn get_matched_rule_identifiers(&self) -> Vec<&str> {
143 self.get_matched_rules()
144 .iter()
145 .map(|rule| rule.name.as_str())
146 .collect()
147 }
148
149 pub fn inspector_url(&self) -> Option<String> {
152 self.get_most_malicious_file().map(|file| {
153 format!(
154 "{}{}",
155 self.inspector_url.as_str(),
156 file.path.to_string_lossy().as_ref()
157 )
158 })
159 }
160}
161
/// The scan results of every distribution of one package release, together
/// with the metadata that `build_body` copies into the submitted report.
pub struct PackageScanResults {
    /// Name of the scanned package.
    pub name: String,
    /// Version of the scanned package.
    pub version: String,
    /// One entry per scanned distribution of this release.
    pub distribution_scan_results: Vec<DistributionScanResults>,
    /// Commit hash reported as `commit` in the submitted body
    /// (presumably identifies the ruleset revision -- confirm with the caller).
    pub commit_hash: String,
}
168
169impl PackageScanResults {
170 pub fn new(
171 name: String,
172 version: String,
173 distribution_scan_results: Vec<DistributionScanResults>,
174 commit_hash: String,
175 ) -> Self {
176 Self {
177 name,
178 version,
179 distribution_scan_results,
180 commit_hash,
181 }
182 }
183
184 pub fn build_body(&self) -> SubmitJobResultsSuccess {
186 let highest_score_distribution = self
187 .distribution_scan_results
188 .iter()
189 .max_by_key(|distrib| distrib.get_total_score());
190
191 let score = highest_score_distribution
192 .map(DistributionScanResults::get_total_score)
193 .unwrap_or_default();
194
195 let inspector_url =
196 highest_score_distribution.and_then(DistributionScanResults::inspector_url);
197
198 let rules_matched = self
200 .distribution_scan_results
201 .iter()
202 .flat_map(DistributionScanResults::get_matched_rule_identifiers)
203 .map(std::string::ToString::to_string)
204 .collect::<HashSet<String>>()
205 .into_iter()
206 .collect();
207
208 SubmitJobResultsSuccess {
209 name: self.name.clone(),
210 version: self.version.clone(),
211 score,
212 inspector_url,
213 rules_matched,
214 commit: self.commit_hash.clone(),
215 }
216 }
217}
218
219pub fn scan_all_distributions(
223 http_client: &Client,
224 rules: &Rules,
225 job: &Job,
226) -> Result<Vec<DistributionScanResults>> {
227 let mut distribution_scan_results = Vec::with_capacity(job.distributions.len());
228 for distribution in &job.distributions {
229 let download_url: Url = distribution.parse().unwrap();
230 let inspector_url = create_inspector_url(&job.name, &job.version, &download_url);
231
232 let dir = download_distribution(http_client, download_url.clone())?;
233
234 let mut dist = Distribution { dir, inspector_url };
235 let distribution_scan_result = dist.scan(rules)?;
236 distribution_scan_results.push(distribution_scan_result);
237 }
238
239 Ok(distribution_scan_results)
240}
241
#[cfg(test)]
mod tests {
    use super::{DistributionScanResults, PackageScanResults};
    use crate::{
        client::{ScanResultSerializer, SubmitJobResultsError, SubmitJobResultsSuccess},
        scanner::{FileScanResult, RuleScore},
    };
    use std::io::Write;
    use std::{collections::HashSet, path::PathBuf};
    use tempfile::{tempdir, tempdir_in};
    use yara::Compiler;

    /// YARA source used by the scanning tests: a single rule with weight 5 that
    /// matches the text "rust" case-insensitively.
    const CONTAINS_RUST_RULE: &str = r#"
 rule contains_rust {
 meta:
 weight = 5
 strings:
 $rust = "rust" nocase
 condition:
 $rust
 }
 "#;

    /// Compiles [`CONTAINS_RUST_RULE`].
    fn compile_contains_rust() -> yara::Rules {
        let compiler = Compiler::new()
            .unwrap()
            .add_rules_str(CONTAINS_RUST_RULE)
            .unwrap();

        compiler.compile_rules().unwrap()
    }

    /// Shorthand for a `RuleScore`.
    fn rule(name: &str, score: i64) -> RuleScore {
        RuleScore {
            name: String::from(name),
            score,
        }
    }

    /// A file result with an empty path and the given rules.
    fn file_matching(rules: Vec<RuleScore>) -> FileScanResult {
        FileScanResult {
            path: PathBuf::default(),
            rules,
        }
    }

    /// Distribution results over `files` with the given base inspector URL.
    fn distribution(files: Vec<FileScanResult>, inspector_url: &str) -> DistributionScanResults {
        DistributionScanResults {
            file_scan_results: files,
            inspector_url: reqwest::Url::parse(inspector_url).unwrap(),
        }
    }

    /// Three files whose rule sets overlap pairwise (rule2 and rule3 appear twice).
    fn overlapping_rules_distribution() -> DistributionScanResults {
        distribution(
            vec![
                file_matching(vec![rule("rule1", 5), rule("rule2", 7)]),
                file_matching(vec![rule("rule2", 7), rule("rule3", 9)]),
                file_matching(vec![rule("rule3", 9), rule("rule4", 6)]),
            ],
            "https://example.net",
        )
    }

    /// A `Distribution` rooted at `dir`.
    fn distro_rooted_at(dir: tempfile::TempDir) -> super::Distribution {
        super::Distribution {
            dir,
            inspector_url: "https://example.com".parse().unwrap(),
        }
    }

    #[test]
    fn test_scan_result_success_serialization() {
        let success = SubmitJobResultsSuccess {
            name: "test".into(),
            version: "1.0.0".into(),
            score: 10,
            inspector_url: Some("inspector url".into()),
            rules_matched: vec!["abc".into(), "def".into()],
            commit: "commit hash".into(),
        };
        let scan_result: ScanResultSerializer = Ok(success).into();

        let expected = r#"{"name":"test","version":"1.0.0","score":10,"inspector_url":"inspector url","rules_matched":["abc","def"],"commit":"commit hash"}"#;
        let actual = serde_json::to_string(&scan_result).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_scan_result_error_serialization() {
        let error = SubmitJobResultsError {
            name: "test".into(),
            version: "1.0.0".into(),
            reason: "Package too large".into(),
        };
        let scan_result: ScanResultSerializer = Err(error).into();

        let expected = r#"{"name":"test","version":"1.0.0","reason":"Package too large"}"#;
        let actual = serde_json::to_string(&scan_result).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_file_score() {
        let file_scan_result = file_matching(vec![rule("rule1", 5), rule("rule2", 7)]);

        assert_eq!(file_scan_result.calculate_score(), 12);
    }

    #[test]
    fn test_get_most_malicious_file() {
        let distribution_scan_results = distribution(
            vec![
                file_matching(vec![rule("rule1", 5)]),
                file_matching(vec![rule("rule2", 7)]),
                file_matching(vec![rule("rule3", 4)]),
            ],
            "https://example.net",
        );

        let most_malicious = distribution_scan_results
            .get_most_malicious_file()
            .unwrap();

        assert_eq!(most_malicious.rules[0].name, "rule2");
    }

    #[test]
    fn test_get_matched_rules() {
        let distribution_scan_results = overlapping_rules_distribution();

        let matched_rules: HashSet<RuleScore> = distribution_scan_results
            .get_matched_rules()
            .into_iter()
            .cloned()
            .collect();

        let expected_rules = HashSet::from([
            rule("rule1", 5),
            rule("rule2", 7),
            rule("rule3", 9),
            rule("rule4", 6),
        ]);

        assert_eq!(matched_rules, expected_rules);
    }

    #[test]
    fn test_get_matched_rule_identifiers() {
        let distribution_scan_results = overlapping_rules_distribution();

        let matched_rule_identifiers = distribution_scan_results.get_matched_rule_identifiers();
        let expected_rule_identifiers = vec!["rule1", "rule2", "rule3", "rule4"];

        assert_eq!(
            HashSet::<_>::from_iter(matched_rule_identifiers),
            HashSet::<_>::from_iter(expected_rule_identifiers)
        );
    }

    #[test]
    fn test_build_package_scan_results_body() {
        let distribution_scan_results1 = distribution(
            vec![
                file_matching(vec![rule("rule1", 5)]),
                file_matching(vec![rule("rule2", 7)]),
            ],
            "https://example.net/distrib1.tar.gz",
        );
        let distribution_scan_results2 = distribution(
            vec![
                file_matching(vec![rule("rule3", 2)]),
                file_matching(vec![rule("rule4", 9)]),
            ],
            "https://example.net/distrib2.whl",
        );

        let package_scan_results = PackageScanResults {
            name: String::from("remmy"),
            version: String::from("4.20.69"),
            distribution_scan_results: vec![distribution_scan_results1, distribution_scan_results2],
            commit_hash: String::from("abc"),
        };

        let body = package_scan_results.build_body();

        assert_eq!(
            body.inspector_url,
            Some(String::from("https://example.net/distrib1.tar.gz"))
        );
        assert_eq!(body.score, 12);
        assert_eq!(
            HashSet::from([
                "rule1".into(),
                "rule2".into(),
                "rule3".into(),
                "rule4".into()
            ]),
            HashSet::from_iter(body.rules_matched)
        );
    }

    #[test]
    fn test_scan_file() {
        let rules = compile_contains_rust();

        // The scanned file lives one directory below the distribution root.
        let tempdir = tempdir().unwrap();
        let archive_root = tempfile::Builder::new().tempdir_in(tempdir.path()).unwrap();
        let mut tmpfile = tempfile::NamedTempFile::new_in(archive_root.path()).unwrap();
        writeln!(&mut tmpfile, "I hate Rust >:(").unwrap();

        let distro = distro_rooted_at(tempdir);

        let result = distro.scan_file(tmpfile.path(), &rules).unwrap();

        assert_eq!(result.rules[0], rule("contains_rust", 5));
        assert_eq!(result.calculate_score(), 5);
    }

    #[test]
    fn test_relative_to_archive_root() {
        let tempdir = tempdir().unwrap();

        let input_path = &tempdir.path().join("name-version").join("README.md");
        let expected_path = PathBuf::from("name-version/README.md");

        let distro = distro_rooted_at(tempdir);

        let result = distro.relative_to_archive_root(input_path).unwrap();

        assert_eq!(expected_path, result);
    }

    #[test]
    fn scan_skips_directories() {
        let rules = compile_contains_rust();

        // One sub-directory and one file in the root: only the file is scanned.
        let tempdir = tempdir().unwrap();
        let _subtempdir = tempdir_in(tempdir.path()).unwrap();
        let mut file_in_root = tempfile::NamedTempFile::new_in(tempdir.path()).unwrap();
        writeln!(&mut file_in_root, "rust").unwrap();

        let mut distro = distro_rooted_at(tempdir);

        let results = distro.scan(&rules).unwrap();

        assert_eq!(results.file_scan_results.len(), 1);
    }
}