diff --git a/cypress/integration/rendering/newShapes.spec.ts b/cypress/integration/rendering/newShapes.spec.ts
index 700efc67f..e26878f40 100644
--- a/cypress/integration/rendering/newShapes.spec.ts
+++ b/cypress/integration/rendering/newShapes.spec.ts
@@ -40,7 +40,28 @@ const newShapesSet5 = [
'card',
'shadedProcess',
] as const;
-const newShapesSet6 = ['curlyBraces'];
+
+const newShapesSet6 = ['roundedRect', 'squareRect', 'stateStart', 'stateEnd', 'labelRect'] as const;
+
+const newShapesSet7 = ['forkJoin', 'choice', 'note', 'stadium'] as const;
+
+const newShapesSet8 = [
+ 'question',
+ 'hexagon',
+ 'curlyBraces',
+ 'multiRect',
+ 'waveEdgedRectangle',
+] as const;
+
+const newShapesSet9 = ['anchor', 'lean_right', 'lean_left', 'trapezoid', 'inv_trapezoid'] as const;
+
+const newShapesSet10 = [
+ 'subroutine',
+ 'cylinder',
+ 'circle',
+ 'doublecircle',
+ 'rect_left_inv_arrow',
+] as const;
// Aggregate all shape sets into a single array
const newShapesSets = [
@@ -50,6 +71,10 @@ const newShapesSets = [
newShapesSet4,
newShapesSet5,
newShapesSet6,
+ newShapesSet7,
+ newShapesSet8,
+ newShapesSet9,
+ newShapesSet10,
] as const;
looks.forEach((look) => {
@@ -67,7 +92,7 @@ looks.forEach((look) => {
it(`with label`, () => {
let flowchartCode = `flowchart ${direction}\n`;
newShapesSet.forEach((newShape, index) => {
- flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is a label' }@\n`;
+ flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is a label for ${newShape} shape' }@\n`;
});
imgSnapshotTest(flowchartCode, { look });
});
@@ -75,7 +100,7 @@ looks.forEach((look) => {
it(`connect all shapes with each other`, () => {
let flowchartCode = `flowchart ${direction}\n`;
newShapesSet.forEach((newShape, index) => {
- flowchartCode += ` n${index}${index}@{ shape: ${newShape}, label: 'This is a label' }@\n`;
+ flowchartCode += ` n${index}${index}@{ shape: ${newShape}, label: 'This is a label for ${newShape} shape' }@\n`;
});
for (let i = 0; i < newShapesSet.length; i++) {
for (let j = i + 1; j < newShapesSet.length; j++) {
@@ -88,7 +113,7 @@ looks.forEach((look) => {
it(`with very long label`, () => {
let flowchartCode = `flowchart ${direction}\n`;
newShapesSet.forEach((newShape, index) => {
- flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is a very very very very very long long long label' }@\n`;
+ flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is a very very very very very long long long label for ${newShape} shape' }@\n`;
});
imgSnapshotTest(flowchartCode, { look });
});
@@ -96,7 +121,7 @@ looks.forEach((look) => {
it(`with markdown htmlLabels:true`, () => {
let flowchartCode = `flowchart ${direction}\n`;
newShapesSet.forEach((newShape, index) => {
- flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is **bold** and strong' }@\n`;
+ flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is **bold** and strong for ${newShape} shape' }@\n`;
});
imgSnapshotTest(flowchartCode, { look });
});
@@ -104,7 +129,7 @@ looks.forEach((look) => {
it(`with markdown htmlLabels:false`, () => {
let flowchartCode = `flowchart ${direction}\n`;
newShapesSet.forEach((newShape, index) => {
- flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is **bold** and strong' }@\n`;
+ flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'This is **bold** and strong for ${newShape} shape' }@\n`;
});
imgSnapshotTest(flowchartCode, {
look,
@@ -116,7 +141,7 @@ looks.forEach((look) => {
it(`with styles`, () => {
let flowchartCode = `flowchart ${direction}\n`;
newShapesSet.forEach((newShape, index) => {
- flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'new shape' }@\n`;
+ flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'new ${newShape} shape' }@\n`;
flowchartCode += ` style n${index}${index} fill:#f9f,stroke:#333,stroke-width:4px \n`;
});
imgSnapshotTest(flowchartCode, { look });
@@ -126,7 +151,7 @@ looks.forEach((look) => {
let flowchartCode = `flowchart ${direction}\n`;
flowchartCode += ` classDef customClazz fill:#bbf,stroke:#f66,stroke-width:2px,color:#fff,stroke-dasharray: 5 5\n`;
newShapesSet.forEach((newShape, index) => {
- flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'new shape' }@\n`;
+ flowchartCode += ` n${index} --> n${index}${index}@{ shape: ${newShape}, label: 'new ${newShape} shape' }@\n`;
flowchartCode += ` n${index}${index}:::customClazz\n`;
});
imgSnapshotTest(flowchartCode, { look });
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts
index 3d6f085a4..ff3e2998f 100644
--- a/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts
+++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts
@@ -3,20 +3,17 @@ import type { Node } from '$root/rendering-util/types.d.ts';
import type { SVG } from '$root/diagram-api/types.js';
// @ts-ignore TODO: Fix rough typings
import rough from 'roughjs';
-import { solidStateFill, styles2String } from './handDrawnShapeStyles.js';
-import { getConfig } from '$root/diagram-api/diagramAPI.js';
+import { styles2String, userNodeOverrides } from './handDrawnShapeStyles.js';
+import { createPathFromPoints, getNodeClasses, labelHelper } from './util.js';
-export const choice = (parent: SVG, node: Node) => {
- const { labelStyles, nodeStyles } = styles2String(node);
- node.labelStyle = labelStyles;
- const { themeVariables } = getConfig();
- const { lineColor } = themeVariables;
- const shapeSvg = parent
- .insert('g')
- .attr('class', 'node default')
- .attr('id', node.domId || node.id);
+export const choice = async (parent: SVG, node: Node) => {
+ const { nodeStyles } = styles2String(node);
+ node.label = '';
+ const { shapeSvg } = await labelHelper(parent, node, getNodeClasses(node));
+ const { cssStyles } = node;
+
+ const s = Math.max(28, node.width ?? 0);
- const s = 28;
const points = [
{ x: 0, y: s / 2 },
{ x: s / 2, y: 0 },
@@ -24,40 +21,34 @@ export const choice = (parent: SVG, node: Node) => {
{ x: -s / 2, y: 0 },
];
- let choice;
- if (node.look === 'handDrawn') {
- // @ts-ignore TODO: Fix rough typings
- const rc = rough.svg(shapeSvg);
- const pointArr = points.map(function (d) {
- return [d.x, d.y];
- });
- const roughNode = rc.polygon(pointArr, solidStateFill(lineColor));
- choice = shapeSvg.insert(() => roughNode);
- } else {
- choice = shapeSvg.insert('polygon', ':first-child').attr(
- 'points',
- points
- .map(function (d) {
- return d.x + ',' + d.y;
- })
- .join(' ')
- );
+ // @ts-ignore TODO: Fix rough typings
+ const rc = rough.svg(shapeSvg);
+ const options = userNodeOverrides(node, {});
+
+ if (node.look !== 'handDrawn') {
+ options.roughness = 0;
+ options.fillStyle = 'solid';
}
- // center the circle around its coordinate
- choice
- .attr('class', 'state-start')
- // @ts-ignore TODO: Fix rough typings
- .attr('r', 7)
- .attr('width', 28)
- .attr('height', 28)
- .attr('style', nodeStyles);
+ const choicePath = createPathFromPoints(points);
+ const roughNode = rc.path(choicePath, options);
+ const choiceShape = shapeSvg.insert(() => roughNode, ':first-child');
+
+ choiceShape.attr('class', 'basic label-container');
+
+ if (cssStyles && node.look !== 'handDrawn') {
+ choiceShape.selectAll('path').attr('style', cssStyles);
+ }
+
+ if (nodeStyles && node.look !== 'handDrawn') {
+ choiceShape.selectAll('path').attr('style', nodeStyles);
+ }
node.width = 28;
node.height = 28;
node.intersect = function (point) {
- return intersect.circle(node, 14, point);
+ return intersect.polygon(node, points, point);
};
return shapeSvg;
diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/forkJoin.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/forkJoin.ts
index 07978be10..687d281f0 100644
--- a/packages/mermaid/src/rendering-util/rendering-elements/shapes/forkJoin.ts
+++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/forkJoin.ts
@@ -1,62 +1,50 @@
-import { updateNodeBounds } from './util.js';
+import { getNodeClasses, labelHelper, updateNodeBounds } from './util.js';
import intersect from '../intersect/index.js';
import type { Node } from '$root/rendering-util/types.d.ts';
import type { SVG } from '$root/diagram-api/types.js';
import rough from 'roughjs';
-import { solidStateFill } from './handDrawnShapeStyles.js';
-import { getConfig } from '$root/diagram-api/diagramAPI.js';
+import { styles2String, userNodeOverrides } from './handDrawnShapeStyles.js';
-export const forkJoin = (parent: SVG, node: Node, dir: string) => {
- const { themeVariables } = getConfig();
- const { lineColor } = themeVariables;
- const shapeSvg = parent
- .insert('g')
- .attr('class', 'node default')
- .attr('id', node.domId || node.id);
+export const forkJoin = async (parent: SVG, node: Node, dir: string) => {
+ const { nodeStyles } = styles2String(node);
+ node.label = '';
+ const { shapeSvg } = await labelHelper(parent, node, getNodeClasses(node));
- let width = 70;
- let height = 10;
+ const { cssStyles } = node;
+ let width = Math.max(70, node?.width ?? 0);
+ let height = Math.max(10, node?.height ?? 0);
if (dir === 'LR') {
- width = 10;
- height = 70;
+ width = Math.max(10, node?.width ?? 0);
+ height = Math.max(70, node?.height ?? 0);
}
+
const x = (-1 * width) / 2;
const y = (-1 * height) / 2;
- let shape;
- if (node.look === 'handDrawn') {
- // @ts-ignore TODO: Fix rough typings
- const rc = rough.svg(shapeSvg);
- const roughNode = rc.rectangle(x, y, width, height, solidStateFill(lineColor));
- shape = shapeSvg.insert(() => roughNode);
- } else {
- shape = shapeSvg
- .append('rect')
- .attr('x', x)
- .attr('y', y)
- .attr('width', width)
- .attr('height', height)
- .attr('class', 'fork-join');
+ // @ts-ignore TODO: Fix rough typings
+ const rc = rough.svg(shapeSvg);
+ const options = userNodeOverrides(node, {});
+
+ if (node.look !== 'handDrawn') {
+ options.roughness = 0;
+ options.fillStyle = 'solid';
+ }
+
+ const roughNode = rc.rectangle(x, y, width, height, options);
+
+ const shape = shapeSvg.insert(() => roughNode, ':first-child');
+
+ if (cssStyles && node.look !== 'handDrawn') {
+ shape.selectAll('path').attr('style', cssStyles);
+ }
+
+ if (nodeStyles && node.look !== 'handDrawn') {
+ shape.selectAll('path').attr('style', nodeStyles);
}
updateNodeBounds(node, shape);
- let nodeHeight = 0;
- let nodeWidth = 0;
- let nodePadding = 10;
- if (node.height) {
- nodeHeight = node.height;
- }
- if (node.width) {
- nodeWidth = node.width;
- }
- if (node.padding) {
- nodePadding = node.padding;
- }
-
- node.height = nodeHeight + nodePadding / 2;
- node.width = nodeWidth + nodePadding / 2;
node.intersect = function (point) {
return intersect.rect(node, point);
};