Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: init gritql python #337

Merged
merged 28 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 17 additions & 7 deletions crates/cli/src/commands/apply_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use marzano_messenger::{
output_mode::OutputMode,
};

use crate::resolver::{get_grit_files_from_cwd, GritModuleResolver};
use crate::resolver::{get_grit_files_from_flags_or_cwd, GritModuleResolver};
use crate::utils::has_uncommitted_changes;

use super::filters::SharedFilterArgs;
Expand Down Expand Up @@ -207,7 +207,7 @@ pub(crate) async fn run_apply_pattern(
details: &mut ApplyDetails,
pattern_libs: Option<BTreeMap<String, String>>,
default_lang: Option<PatternLanguage>,
format: &GlobalFormatFlags,
format_flags: &GlobalFormatFlags,
root_path: Option<PathBuf>,
) -> Result<()> {
let mut context = Updater::from_current_bin()
Expand All @@ -217,7 +217,7 @@ pub(crate) async fn run_apply_pattern(
.unwrap();

let format = OutputFormat::from_flags(
format,
format_flags,
if arg.stdin {
OutputFormat::Transformed
} else {
Expand Down Expand Up @@ -283,7 +283,7 @@ pub(crate) async fn run_apply_pattern(
let module_resolution = span!(tracing::Level::INFO, "module_resolution",).entered();

// Construct a resolver
let resolver = GritModuleResolver::new(cwd.to_str().unwrap());
let resolver = GritModuleResolver::new();
let current_repo_root = marzano_gritmodule::fetcher::LocalRepo::from_dir(&cwd)
.await
.map(|repo| repo.root())
Expand Down Expand Up @@ -312,7 +312,14 @@ pub(crate) async fn run_apply_pattern(
#[cfg(feature = "grit_tracing")]
let stdlib_download_span = span!(tracing::Level::INFO, "stdlib_download",).entered();

let mod_dir = find_grit_modules_dir(cwd.clone()).await;
let target_grit_dir = format_flags
.grit_dir
.as_ref()
.and_then(|c| c.parent())
.unwrap_or_else(|| &cwd)
.to_path_buf();
let mod_dir = find_grit_modules_dir(target_grit_dir.clone()).await;

if !env::var("GRIT_DOWNLOADS_DISABLED")
.unwrap_or_else(|_| "false".to_owned())
.parse::<bool>()
Expand All @@ -321,7 +328,7 @@ pub(crate) async fn run_apply_pattern(
{
flushable_unwrap!(
emitter,
init_config_from_cwd::<KeepFetcherKind>(cwd.clone(), false).await
init_config_from_cwd::<KeepFetcherKind>(target_grit_dir, false).await
);
}

Expand Down Expand Up @@ -349,7 +356,10 @@ pub(crate) async fn run_apply_pattern(
#[cfg(feature = "grit_tracing")]
let grit_file_discovery = span!(tracing::Level::INFO, "grit_file_discovery",).entered();

let pattern_libs = flushable_unwrap!(emitter, get_grit_files_from_cwd().await);
let pattern_libs = flushable_unwrap!(
emitter,
get_grit_files_from_flags_or_cwd(format_flags).await
);

let (mut lang, pattern_body) = if pattern.ends_with(".grit") || pattern.ends_with(".md") {
match fs::read_to_string(pattern.clone()).await {
Expand Down
9 changes: 6 additions & 3 deletions crates/cli/src/commands/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{
github::{log_check_annotations, write_check_summary},
messenger_variant::create_emitter,
resolver::{
get_grit_files_from, get_grit_files_from_cwd, resolve_from, resolve_from_cwd,
get_grit_files_from, get_grit_files_from_flags_or_cwd, resolve_from, resolve_from_cwd,
GritModuleResolver, Source,
},
scan::log_check_json,
Expand Down Expand Up @@ -102,7 +102,10 @@ pub(crate) async fn run_check(
grit_files.merge(global_files);
(resolved, grit_files)
} else {
try_join![resolve_from_cwd(&Source::All), get_grit_files_from_cwd()]?
try_join![
resolve_from_cwd(&Source::All),
get_grit_files_from_flags_or_cwd(format)
]?
};

let enforced = resolved_patterns
Expand All @@ -122,7 +125,7 @@ pub(crate) async fn run_check(
let filter_range = extract_filter_ranges(&arg.shared_filters, Some(&current_dir))?;

// Construct a resolver
let resolver = GritModuleResolver::new(current_dir.to_str().unwrap());
let resolver = GritModuleResolver::new();

let mut body_to_pattern: HashMap<String, &ResolvedGritDefinition> = HashMap::new();
let compile_tasks: Result<HashMap<String, Problem>, _> = enforced
Expand Down
12 changes: 4 additions & 8 deletions crates/cli/src/commands/parse.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use crate::{
flags::GlobalFormatFlags,
jsonl::JSONLineMessenger,
resolver::{get_grit_files_from_cwd, GritModuleResolver},
};
use crate::{flags::GlobalFormatFlags, jsonl::JSONLineMessenger, resolver::GritModuleResolver};
use anyhow::{bail, Result};
use clap::Args;
use grit_util::Position;
Expand Down Expand Up @@ -85,12 +81,12 @@ pub(crate) async fn run_parse(
Ok(())
}

#[allow(deprecated)]
async fn parse_one_pattern(body: String, path: Option<&PathBuf>) -> Result<MatchResult> {
let current_dir = std::env::current_dir()?;
let resolver = GritModuleResolver::new(current_dir.to_str().unwrap());
let resolver = GritModuleResolver::new();
let lang = PatternLanguage::get_language(&body);
let pattern = resolver.make_pattern(&body, None)?;
let pattern_libs = get_grit_files_from_cwd().await?;
let pattern_libs = crate::resolver::get_grit_files_from_cwd().await?;
let pattern_libs = pattern_libs.get_language_directory_or_default(lang)?;
let problem = match pattern.compile(&pattern_libs, None, None, None) {
Ok(problem) => problem,
Expand Down
4 changes: 2 additions & 2 deletions crates/cli/src/commands/patterns_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use marzano_gritmodule::config::{DefinitionSource, ResolvedGritDefinition};
use crate::{
flags::GlobalFormatFlags,
lister::{list_applyables, Listable},
resolver::resolve_from_cwd,
resolver::{resolve_from_flags_or_cwd},
};

use super::list::ListArgs;
Expand Down Expand Up @@ -36,6 +36,6 @@ impl Listable for ResolvedGritDefinition {
}

pub(crate) async fn run_patterns_list(arg: ListArgs, parent: GlobalFormatFlags) -> Result<()> {
let (resolved, curr_repo) = resolve_from_cwd(&arg.source).await?;
let (resolved, curr_repo) = resolve_from_flags_or_cwd(&parent, &arg.source).await?;
list_applyables(false, false, resolved, arg.level, &parent, curr_repo).await
}
9 changes: 5 additions & 4 deletions crates/cli/src/commands/patterns_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIter
use serde::Serialize;

use crate::flags::{GlobalFormatFlags, OutputFormat};
use crate::resolver::{get_grit_files_from_cwd, resolve_from_cwd, GritModuleResolver, Source};
use crate::resolver::{
get_grit_files_from_flags_or_cwd, resolve_from_cwd, GritModuleResolver, Source,
};
use crate::result_formatting::FormattedResult;
use crate::updater::Updater;
use crate::ux::{indent, log_test_diff};
Expand All @@ -34,8 +36,7 @@ pub async fn get_marzano_pattern_test_results(
args: PatternsTestArgs,
output: OutputFormat,
) -> Result<()> {
let cwd = std::env::current_dir()?;
let resolver = GritModuleResolver::new(cwd.to_str().unwrap());
let resolver = GritModuleResolver::new();

let final_results: DashMap<String, Vec<WrappedResult>> = DashMap::new();
let unformatted_results: DashMap<PatternLanguage, Vec<WrappedResult>> = DashMap::new();
Expand Down Expand Up @@ -250,7 +251,7 @@ pub(crate) async fn run_patterns_test(
flags: GlobalFormatFlags,
) -> Result<()> {
let (mut patterns, _) = resolve_from_cwd(&Source::Local).await?;
let libs = get_grit_files_from_cwd().await?;
let libs = get_grit_files_from_flags_or_cwd(&flags).await?;

if arg.filter.is_some() {
let filter = arg.filter.as_ref().unwrap();
Expand Down
3 changes: 3 additions & 0 deletions crates/cli/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ pub struct GlobalFormatFlags {
/// Override the default log level (info)
#[arg(long, global = true)]
pub log_level: Option<log::LevelFilter>,
/// Override the default .grit directory location
#[arg(long, global = true)]
pub grit_dir: Option<std::path::PathBuf>,
}

#[derive(Debug, PartialEq, Clone)]
Expand Down
94 changes: 61 additions & 33 deletions crates/cli/src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ use colored::Colorize;
use core::fmt;
use log::{info, warn};
use serde::Serialize;
use std::{collections::HashMap, path::PathBuf, str::FromStr};
use std::{
collections::HashMap,
path::{Path, PathBuf},
str::FromStr,
};

use anyhow::{Context, Result};
use marzano_gritmodule::{
Expand All @@ -14,7 +18,7 @@ use marzano_gritmodule::{
searcher::find_grit_dir_from,
};

use crate::updater::Updater;
use crate::{flags::GlobalFormatFlags, updater::Updater};

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Serialize, Debug)]
#[serde(rename_all = "lowercase")]
Expand All @@ -28,8 +32,24 @@ pub enum Source {
}

// Equivalent to our PatternResolver in zesty, but more minimal
pub struct GritModuleResolver<'a> {
_root_directory: &'a str,
pub struct GritModuleResolver {}

impl GritModuleResolver {
pub fn new() -> Self {
Self {}
}

pub fn make_pattern<'b>(
&self,
pattern_input: &'b str,
name: Option<String>,
) -> Result<RichPattern<'b>> {
let pattern = RichPattern {
body: pattern_input,
name,
};
Ok(pattern)
}
}

#[derive(Debug)]
Expand All @@ -44,24 +64,16 @@ impl<'b> fmt::Display for RichPattern<'b> {
}
}

impl<'a> GritModuleResolver<'a> {
pub fn new(root_directory: &'a str) -> Self {
Self {
_root_directory: root_directory,
}
}
async fn from_known_grit_dir(config_path: &Path) -> Result<PatternsDirectory> {
let stdlib_modules = get_stdlib_modules();

pub fn make_pattern<'b>(
&self,
pattern_input: &'b str,
name: Option<String>,
) -> Result<RichPattern<'b>> {
let pattern = RichPattern {
body: pattern_input,
name,
};
Ok(pattern)
}
let grit_parent = PathBuf::from(config_path.parent().context(format!(
"Unable to find parent of .grit directory at {}",
config_path.to_string_lossy()
))?);
let parent_str = &grit_parent.to_string_lossy().to_string();
let repo = ModuleRepo::from_dir(config_path).await;
get_grit_files(&repo, parent_str, Some(stdlib_modules)).await
}

pub async fn get_grit_files_from(cwd: Option<PathBuf>) -> Result<PatternsDirectory> {
Expand All @@ -70,20 +82,12 @@ pub async fn get_grit_files_from(cwd: Option<PathBuf>) -> Result<PatternsDirecto
} else {
None
};
let stdlib_modules = get_stdlib_modules();

match existing_config {
Some(config) => {
let config_path = PathBuf::from_str(&config).unwrap();
let grit_parent = PathBuf::from(config_path.parent().context(format!(
"Unable to find parent of .grit directory at {}",
config
))?);
let parent_str = &grit_parent.to_string_lossy().to_string();
let repo = ModuleRepo::from_dir(&config_path).await;
get_grit_files(&repo, parent_str, Some(stdlib_modules)).await
}
Some(config) => from_known_grit_dir(&PathBuf::from(config)).await,
None => {
let stdlib_modules = get_stdlib_modules();

let updater = Updater::from_current_bin().await?;
let install_path = updater.install_path;
let repo = ModuleRepo::from_dir(&install_path).await;
Expand All @@ -92,12 +96,36 @@ pub async fn get_grit_files_from(cwd: Option<PathBuf>) -> Result<PatternsDirecto
}
}

#[tracing::instrument]
/// Get the grit files from the current working directory
#[deprecated = "Use get_grit_files_from_flags_or_cwd instead"]
pub async fn get_grit_files_from_cwd() -> Result<PatternsDirectory> {
let cwd = std::env::current_dir()?;
get_grit_files_from(Some(cwd)).await
}

#[tracing::instrument]
pub async fn get_grit_files_from_flags_or_cwd(
flags: &GlobalFormatFlags,
) -> Result<PatternsDirectory> {
if let Some(grit_dir) = &flags.grit_dir {
from_known_grit_dir(grit_dir).await
} else {
let cwd = std::env::current_dir()?;
get_grit_files_from(Some(cwd)).await
}
}

pub async fn resolve_from_flags_or_cwd(
flags: &GlobalFormatFlags,
source: &Source,
) -> Result<(Vec<ResolvedGritDefinition>, ModuleRepo)> {
if let Some(grit_dir) = &flags.grit_dir {
resolve_from(grit_dir.clone(), source).await
} else {
resolve_from_cwd(source).await
}
}

pub async fn resolve_from(
cwd: PathBuf,
source: &Source,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
language python

pattern special_pattern() {
`os.getenv` => `dotenv.mygoodness`
}
40 changes: 40 additions & 0 deletions crates/cli_bin/tests/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2138,6 +2138,46 @@ fn ignores_file_in_grit_dir() -> Result<()> {
Ok(())
}

#[test]
fn override_grit_modules_at_apply() -> Result<()> {
// Grab the other grit directory
let (_temp_dir, other_dir) = get_fixture("override_custom_grit_dir", true)?;

// Keep _temp_dir around so that the tempdir is not deleted
let (_temp_dir, dir) = get_fixture("simple_python", false)?;
let origin_content = std::fs::read_to_string(dir.join("main.py"))?;

// from the tempdir as cwd, run marzano apply
let mut apply_cmd = get_test_cmd()?;
apply_cmd.current_dir(dir.as_path());
apply_cmd
.arg("apply")
.arg("--force")
.arg("special_pattern")
.arg("--grit-dir")
.arg(other_dir.join(".grit"));
let output = apply_cmd.output()?;

let stdout = String::from_utf8(output.stdout)?;
println!("stdout: {:?}", stdout);
let stderr = String::from_utf8(output.stderr)?;
println!("stderr: {:?}", stderr);

// Assert that the command failed
assert!(output.status.success(),);

// Read back the main.py file
let target_file = dir.join("main.py");
let content: String = std::fs::read_to_string(target_file)?;

assert_ne!(origin_content, content);

// Make sure it now has dotenv.mygoodness
assert!(content.contains("dotenv.mygoodness"));

Ok(())
}

#[test]
fn language_option_file_pattern_apply() -> Result<()> {
// Keep _temp_dir around so that the tempdir is not deleted
Expand Down