Skip to content

Commit

Permalink
Merge pull request #37 from Quramy/fix_optional_relation_type
Browse files Browse the repository at this point in the history
Fix optional relation type
  • Loading branch information
Quramy committed Nov 24, 2022
2 parents f298bd1 + 84bab8d commit 85505d5
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 41 deletions.
2 changes: 1 addition & 1 deletion examples/example-prj/src/__generated__/fabbrica/index.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 36 additions & 28 deletions packages/prisma-fabbrica/src/templates/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ import { template } from "talt";

type StripCreate<T extends string> = T extends `create${infer S}` ? Uncapitalize<S> : T;

function byName<T extends { readonly name: string }>(name: string | { readonly name: string }) {
return (x: T) => x.name === (typeof name === "string" ? name : name.name);
}

function camelize(pascal: string) {
return pascal[0].toLowerCase() + pascal.slice(1);
}
Expand Down Expand Up @@ -56,6 +60,12 @@ function filterRequiredInputObjectTypeField(inputType: DMMF.InputType) {
return filterRequiredFields(inputType).filter(isInputObjectTypeField);
}

function filterBelongsToField(model: DMMF.Model, inputType: DMMF.InputType) {
return inputType.fields
.filter(isInputObjectTypeField)
.filter(field => model.fields.find(byName(field))?.isList === false);
}

function filterEnumFields(inputType: DMMF.InputType) {
return inputType.fields.filter(
field =>
Expand All @@ -65,7 +75,7 @@ function filterEnumFields(inputType: DMMF.InputType) {

function extractFirstEnumValue(enums: DMMF.SchemaEnum[], field: DMMF.SchemaArg) {
const typeName = field.inputTypes[0].type;
const found = enums.find(e => e.name === field.inputTypes[0].type);
const found = enums.find(byName(typeName));
if (!found) {
throw new Error(`Not found enum ${typeName}`);
}
Expand All @@ -88,7 +98,7 @@ export const importStatement = (specifier: string, prismaClientModuleSpecifier:
`();

export const scalarFieldType = (
modelName: string,
model: DMMF.Model,
fieldName: string,
inputType: DMMF.SchemaArgInputType,
): ts.TypeNode => {
Expand Down Expand Up @@ -116,14 +126,14 @@ export const scalarFieldType = (
// return template.typeNode`Prisma.Json`();
return ast.keywordTypeNode(ts.SyntaxKind.AnyKeyword);
default:
throw new Error(`Unknown scalar type "${inputType.type}" for ${modelName}.${fieldName} .`);
throw new Error(`Unknown scalar type "${inputType.type}" for ${model.name}.${fieldName} .`);
}
};

export const argInputType = (modelName: string, fieldName: string, inputType: DMMF.SchemaArgInputType): ts.TypeNode => {
export const argInputType = (model: DMMF.Model, fieldName: string, inputType: DMMF.SchemaArgInputType): ts.TypeNode => {
const fieldType = () => {
if (inputType.location === "scalar") {
return scalarFieldType(modelName, fieldName, inputType);
return scalarFieldType(model, fieldName, inputType);
} else if (inputType.location === "enumTypes") {
return ast.typeReferenceNode(ast.identifier(inputType.type as string));
} else if (inputType.location === "outputObjectTypes" || inputType.location === "inputObjectTypes") {
Expand All @@ -140,7 +150,7 @@ export const argInputType = (modelName: string, fieldName: string, inputType: DM
: fieldType();
};

export const modelScalarOrEnumFields = (modelName: string, inputType: DMMF.InputType) =>
export const modelScalarOrEnumFields = (model: DMMF.Model, inputType: DMMF.InputType) =>
template.statement<ts.TypeAliasDeclaration>`
type MODEL_SCALAR_OR_ENUM_FIELDS = ${() =>
ast.typeLiteralNode(
Expand All @@ -149,18 +159,16 @@ export const modelScalarOrEnumFields = (modelName: string, inputType: DMMF.Input
undefined,
field.name,
undefined,
ast.unionTypeNode(
field.inputTypes.map(childInputType => argInputType(modelName, field.name, childInputType)),
),
ast.unionTypeNode(field.inputTypes.map(childInputType => argInputType(model, field.name, childInputType))),
),
),
)}
`({
MODEL_SCALAR_OR_ENUM_FIELDS: ast.identifier(`${modelName}ScalarOrEnumFields`),
MODEL_SCALAR_OR_ENUM_FIELDS: ast.identifier(`${model.name}ScalarOrEnumFields`),
});

export const modelBelongsToRelationFactory = (fieldType: DMMF.SchemaArg, model: DMMF.Model) => {
const targetModel = model.fields.find(f => f.name === fieldType.name)!;
const targetModel = model.fields.find(byName(fieldType))!;
return template.statement<ts.TypeAliasDeclaration>`
type ${() => ast.identifier(`${model.name}${fieldType.name}Factory`)} = {
_factoryFor: ${() => ast.literalTypeNode(ast.stringLiteral(targetModel.type))};
Expand All @@ -170,7 +178,7 @@ export const modelBelongsToRelationFactory = (fieldType: DMMF.SchemaArg, model:
`();
};

export const modelFactoryDefineInput = (modelName: string, inputType: DMMF.InputType) =>
export const modelFactoryDefineInput = (model: DMMF.Model, inputType: DMMF.InputType) =>
template.statement<ts.TypeAliasDeclaration>`
type MODEL_FACTORY_DEFINE_INPUT = ${() =>
ast.typeLiteralNode(
Expand All @@ -180,16 +188,17 @@ export const modelFactoryDefineInput = (modelName: string, inputType: DMMF.Input
field.name,
!field.isRequired || isScalarOrEnumField(field) ? ast.token(ts.SyntaxKind.QuestionToken) : undefined,
ast.unionTypeNode([
...(field.isRequired && isInputObjectTypeField(field)
? [ast.typeReferenceNode(ast.identifier(`${modelName}${field.name}Factory`))]
...((field.isRequired || model.fields.find(byName(field))!.isList === false) &&
isInputObjectTypeField(field)
? [ast.typeReferenceNode(ast.identifier(`${model.name}${field.name}Factory`))]
: []),
...field.inputTypes.map(childInputType => argInputType(modelName, field.name, childInputType)),
...field.inputTypes.map(childInputType => argInputType(model, field.name, childInputType)),
]),
),
),
)};
`({
MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${modelName}FactoryDefineInput`),
MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${model.name}FactoryDefineInput`),
});

export const modelFactoryDefineOptions = (modelName: string, isOpionalDefaultData: boolean) =>
Expand All @@ -212,12 +221,13 @@ export const modelFactoryDefineOptions = (modelName: string, isOpionalDefaultDat
});

export const isModelAssociationFactory = (fieldType: DMMF.SchemaArg, model: DMMF.Model) => {
const targetModel = model.fields.find(f => f.name === fieldType.name)!;
const targetModel = model.fields.find(byName(fieldType))!;
return template.statement<ts.FunctionDeclaration>`
function ${() => ast.identifier(`is${model.name}${fieldType.name}Factory`)}(
x: MODEL_BELONGS_TO_RELATION_FACTORY | ${() => argInputType(model.name, fieldType.name, fieldType.inputTypes[0])}
x: MODEL_BELONGS_TO_RELATION_FACTORY | ${() =>
argInputType(model, fieldType.name, fieldType.inputTypes[0])} | undefined
): x is MODEL_BELONGS_TO_RELATION_FACTORY {
return (x as any)._factoryFor === ${() => ast.stringLiteral(targetModel.type)};
return (x as any)?._factoryFor === ${() => ast.stringLiteral(targetModel.type)};
}
`({
MODEL_BELONGS_TO_RELATION_FACTORY: ast.typeReferenceNode(`${model.name}${fieldType.name}Factory`),
Expand All @@ -237,10 +247,10 @@ export const autoGenerateModelScalarsOrEnumsFieldArgs = (
MODEL_NAME: ast.stringLiteral(model.name),
FIELD_NAME: ast.stringLiteral(field.name),
IS_ID:
model.fields.find(f => f.name === field.name)!.isId || model.primaryKey?.fields.includes(field.name)
model.fields.find(byName(field))!.isId || model.primaryKey?.fields.includes(field.name)
? ast.true()
: ast.false(),
IS_UNIQUE: model.fields.find(f => f.name === field.name)!.isUnique ? ast.true() : ast.false(),
IS_UNIQUE: model.fields.find(byName(field))!.isUnique ? ast.true() : ast.false(),
})
: ast.stringLiteral(extractFirstEnumValue(enums, field));

Expand Down Expand Up @@ -279,7 +289,7 @@ export const defineModelFactoryInernal = (model: DMMF.Model, inputType: DMMF.Inp
const defaultData= await resolveValue(defaultDataResolver ?? {});
const defaultAssociations = ${() =>
ast.objectLiteralExpression(
filterRequiredInputObjectTypeField(inputType).map(field =>
filterBelongsToField(model, inputType).map(field =>
ast.propertyAssignment(
field.name,
template.expression`
Expand Down Expand Up @@ -373,15 +383,13 @@ export function getSourceFile({
...document.datamodel.models
.map(model => ({ model, createInputType: findPrsimaCreateInputTypeFromModelName(document, model.name) }))
.flatMap(({ model, createInputType }) => [
modelScalarOrEnumFields(model.name, createInputType),
...filterRequiredInputObjectTypeField(createInputType).map(fieldType =>
modelScalarOrEnumFields(model, createInputType),
...filterBelongsToField(model, createInputType).map(fieldType =>
modelBelongsToRelationFactory(fieldType, model),
),
modelFactoryDefineInput(model.name, createInputType),
modelFactoryDefineInput(model, createInputType),
modelFactoryDefineOptions(model.name, filterRequiredInputObjectTypeField(createInputType).length === 0),
...filterRequiredInputObjectTypeField(createInputType).map(fieldType =>
isModelAssociationFactory(fieldType, model),
),
...filterBelongsToField(model, createInputType).map(fieldType => isModelAssociationFactory(fieldType, model)),
autoGenerateModelScalarsOrEnums(model, createInputType, document.schema.enumTypes.model ?? []),
defineModelFactoryInernal(model, createInputType),
defineModelFactory(model.name, createInputType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ describe(modelScalarOrEnumFields, () => {
});
const inputType = findPrsimaCreateInputTypeFromModelName(dmmf, "TestModel");
const source = template.statement(expected)();
expect(printNode(modelScalarOrEnumFields("TestModel", inputType))).toBe(printNode(source).trim());
expect(printNode(modelScalarOrEnumFields(dmmf.datamodel.models[0], inputType))).toBe(printNode(source).trim());
});

it("does not generate for nullable field", async () => {
Expand All @@ -70,6 +70,6 @@ describe(modelScalarOrEnumFields, () => {
id: number;
}
`();
expect(printNode(modelScalarOrEnumFields("TestModel", inputType))).toBe(printNode(expected).trim());
expect(printNode(modelScalarOrEnumFields(dmmf.datamodel.models[0], inputType))).toBe(printNode(expected).trim());
});
});

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ export const ReviewFactory = defineReviewFactory({
reviewer: UserFactory,
},
});
export const PostFactoryAlt = definePostFactory({
defaultData: {
author: UserFactory,
},
});

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 85505d5

Please sign in to comment.