Skip to content

Commit b3198ac

Browse files
author
Johannes Coetzee
authored
Make javasrc method call AST structure more consistent with c2cpg (joernio#828)
* Make javasrc method call AST structure more consistent with c2cpg
* Fix formatting
* Suppress unused warning
* Fix test indices
* Fix more indices
1 parent 7eced91 commit b3198ac

File tree

10 files changed

+200
-43
lines changed

10 files changed

+200
-43
lines changed

console/src/main/scala/io/joern/console/QueryDatabase.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ class QueryDatabase(defaultArgumentProvider: DefaultArgumentProvider = new Defau
139139
* */
140140
class DefaultArgumentProvider {
141141

142-
def typeSpecificDefaultArg(argTypeFullName: String): Option[Any] = {
142+
def typeSpecificDefaultArg(@unused argTypeFullName: String): Option[Any] = {
143143
None
144144
}
145145

dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
package io.joern.dataflowengineoss.passes.reachingdef
22

3-
import io.shiftleft.codepropertygraph.generated.EdgeTypes
3+
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators}
44
import io.shiftleft.codepropertygraph.generated.nodes._
55
import io.shiftleft.semanticcpg.language._
66
import io.shiftleft.semanticcpg.utils.MemberAccess.isGenericMemberAccessName
@@ -221,8 +221,19 @@ class ReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) extends Trans
221221
allIdentifiers(param.name)
222222
.filter(x => x.id != param.id)
223223
case identifier: Identifier =>
224-
allIdentifiers(identifier.name)
224+
val sameIdentifiers = allIdentifiers(identifier.name)
225225
.filter(x => x.id != identifier.id)
226+
227+
/**
228+
* Killing an identifier should also kill field accesses on that identifier.
229+
* For example, a reassignment `x = new Box()` should kill any previous
230+
* calls to `x.value`, `x.length()`, etc.
231+
*/
232+
val sameObjects: Iterable[Call] = allCalls.values.flatten
233+
.filter(_.name == Operators.fieldAccess)
234+
.filter(_.ast.isIdentifier.name(identifier.name).nonEmpty)
235+
236+
sameIdentifiers ++ sameObjects
226237
case call: Call =>
227238
allCalls(call.code)
228239
.filter(x => x.id != call.id)

joern-cli/frontends/c2cpg/src/test/scala/io/joern/c2cpg/querying/CDataFlowTests.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -980,7 +980,7 @@ class CDataFlowTests32 extends DataFlowCodeToCpgSuite {
980980
}
981981
}
982982

983-
class CDatFlowTests33 extends DataFlowCodeToCpgSuite {
983+
class CDataFlowTests33 extends DataFlowCodeToCpgSuite {
984984
override val code =
985985
"""
986986
|int bar(int z) {
@@ -995,7 +995,7 @@ class CDatFlowTests33 extends DataFlowCodeToCpgSuite {
995995
|}
996996
|""".stripMargin
997997

998-
"Tes 33: should provide correct flow for source in sibling callee" in {
998+
"Test 33: should provide correct flow for source in sibling callee" in {
999999
cpg.call("sink").reachableByFlows(cpg.call("source")).size shouldBe 1
10001000
}
10011001

joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala

Lines changed: 46 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1657,10 +1657,11 @@ class AstCreator(filename: String, global: Global) {
16571657
resolvedDecl: Try[ResolvedMethodDeclaration],
16581658
order: Int
16591659
) = {
1660+
val codePrefix = call.getScope.toScala.map(_.toString).getOrElse("this")
16601661
val typeFullName = registerType(Try(call.calculateResolvedType().describe()).getOrElse("<empty>"))
16611662
val callNode = NewCall()
16621663
.name(call.getNameAsString)
1663-
.code(s"${call.getNameAsString}(${call.getArguments.asScala.mkString(", ")})")
1664+
.code(s"${codePrefix}.${call.getNameAsString}(${call.getArguments.asScala.mkString(", ")})")
16641665
.typeFullName(typeFullName)
16651666
.order(order)
16661667
.argumentIndex(order)
@@ -1685,6 +1686,7 @@ class AstCreator(filename: String, global: Global) {
16851686
}
16861687

16871688
private def createThisNode(
1689+
call: MethodCallExpr,
16881690
resolvedDecl: Try[ResolvedMethodDeclaration]
16891691
): Option[NewIdentifier] = {
16901692
resolvedDecl.toOption
@@ -1697,6 +1699,8 @@ class AstCreator(filename: String, global: Global) {
16971699
.typeFullName(typeFullName)
16981700
.order(0)
16991701
.argumentIndex(0)
1702+
.lineNumber(line(call))
1703+
.columnNumber(column(call))
17001704
}
17011705
}
17021706

@@ -1944,27 +1948,52 @@ class AstCreator(filename: String, global: Global) {
19441948

19451949
val resolvedDecl = Try(call.resolve())
19461950
val callNode = createCallNode(call, resolvedDecl, order)
1947-
val thisAsts = createThisNode(resolvedDecl)
1948-
.map(_.lineNumber(line(call)))
1949-
.map(_.columnNumber(column(call)))
1950-
.map(x => AstWithCtx(Ast(x), Context(identifiers = List(x))))
1951-
.toList
19521951

1953-
val argAsts = withOrder(call.getArguments) { (arg, order) =>
1954-
// FIXME: There's an implicit assumption here that each call to
1955-
// astsForExpression only returns a single tree.
1956-
astsForExpression(arg, scopeContext, order)
1952+
val scopeAst: AstWithCtx = call.getScope.toScala match {
1953+
case Some(scope) =>
1954+
createFieldAccessForMethodCall(call, scope, scopeContext)
1955+
1956+
case None =>
1957+
val node = createThisNode(call, resolvedDecl)
1958+
node.map(x => AstWithCtx(Ast(x), Context(identifiers = List(x)))).getOrElse(AstWithCtx.empty)
1959+
}
1960+
1961+
val argumentAsts = withOrder(call.getArguments) { (x, o) =>
1962+
astsForExpression(x, scopeContext, o)
19571963
}.flatten
19581964

1959-
val ast = Ast(callNode)
1960-
.withChildren(thisAsts.map(_.ast))
1961-
.withChildren(argAsts.map(_.ast))
1962-
.withArgEdges(callNode, thisAsts.flatMap(_.ast.root))
1963-
.withArgEdges(callNode, argAsts.flatMap(_.ast.root))
1965+
callAst(callNode, Seq(scopeAst) ++ argumentAsts)
1966+
}
19641967

1965-
val ctx = mergedCtx((thisAsts ++ argAsts).map(_.ctx))
1968+
private def createFieldAccessForMethodCall(
1969+
call: MethodCallExpr,
1970+
scope: Expression,
1971+
scopeContext: ScopeContext,
1972+
): AstWithCtx = {
1973+
val name = call.getName.toString
19661974

1967-
AstWithCtx(ast, ctx)
1975+
val callNode = NewCall()
1976+
.code(s"${scope.toString}.$name")
1977+
.name(Operators.fieldAccess)
1978+
.methodFullName(Operators.fieldAccess)
1979+
.dispatchType(DispatchTypes.STATIC_DISPATCH)
1980+
.order(0)
1981+
.argumentIndex(0)
1982+
.lineNumber(line(call))
1983+
.columnNumber(column(call))
1984+
1985+
val scopeAst = astsForExpression(scope, scopeContext, 1)
1986+
1987+
val fieldIdentifier = NewFieldIdentifier()
1988+
.canonicalName(name)
1989+
.code(name)
1990+
.lineNumber(line(call))
1991+
.columnNumber(column(call))
1992+
.argumentIndex(2)
1993+
.order(2)
1994+
val fieldAstWithCtx = AstWithCtx(Ast(fieldIdentifier), Context())
1995+
1996+
callAst(callNode, scopeAst ++ Seq(fieldAstWithCtx))
19681997
}
19691998

19701999
private def tryResolveType(node: NodeWithType[_, _ <: Resolvable[ResolvedType]]): String = {

joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallGraphTests.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,16 +24,16 @@ class CallGraphTests extends JavaSrcCodeToCpgFixture {
2424
}
2525

2626
"should find that main calls add and others" in {
27-
cpg.method.name("main").callee.name.toSet shouldBe Set("add", "println", "<operator>.addition")
27+
cpg.method.name("main").callee.name.toSet shouldBe Set("add", "println", "<operator>.addition", "<operator>.fieldAccess")
2828
}
2929

3030
"should find three outgoing calls for main" in {
3131
cpg.method.name("main").call.code.toSet shouldBe
32-
Set("1 + 2", "add(1 + 2, 3)", "println(add(1 + 2, 3))")
32+
Set("1 + 2", "this.add(1 + 2, 3)", "System.out.println(add(1 + 2, 3))", "System.out", "System.out.println")
3333
}
3434

3535
"should find one callsite for add" in {
36-
cpg.method.name("add").callIn.code.toSet shouldBe Set("add(1 + 2, 3)")
36+
cpg.method.name("add").callIn.code.toSet shouldBe Set("this.add(1 + 2, 3)")
3737
}
3838

3939
"should find that argument '1+2' is passed to parameter 'x'" in {

joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala

Lines changed: 112 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
package io.joern.javasrc2cpg.querying
22

33
import io.joern.javasrc2cpg.testfixtures.JavaSrcCodeToCpgFixture
4-
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, nodes}
5-
import io.shiftleft.codepropertygraph.generated.nodes.Call
4+
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes}
5+
import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier, Literal}
66
import io.shiftleft.semanticcpg.language.NoResolve
77
import io.shiftleft.semanticcpg.language._
88

@@ -12,6 +12,7 @@ class CallTests extends JavaSrcCodeToCpgFixture {
1212

1313
override val code: String =
1414
"""
15+
|package test;
1516
| class Foo {
1617
| int add(int x, int y) {
1718
| return x + y;
@@ -25,17 +26,39 @@ class CallTests extends JavaSrcCodeToCpgFixture {
2526
| foo(argc);
2627
| }
2728
| }
29+
|
30+
|class MyObject {
31+
| public String myMethod(String s) {
32+
| return s;
33+
| }
34+
|}
35+
|
36+
|public class Bar {
37+
| MyObject obj = new MyObject();
38+
|
39+
| public String foo(MyObject myObj) {
40+
| return myObj.myMethod("Hello, world!");
41+
| }
42+
|
43+
| public void bar() {
44+
| foo(obj);
45+
| }
46+
|
47+
| public void baz() {
48+
| this.foo(obj);
49+
| }
50+
|}
2851
|""".stripMargin
2952

3053
"should contain a call node for `add` with correct fields" in {
3154
val List(x) = cpg.call("add").l
32-
x.code shouldBe "add(argc, 3)"
55+
x.code shouldBe "this.add(argc, 3)"
3356
x.name shouldBe "add"
3457
x.order shouldBe 2
35-
x.methodFullName shouldBe "Foo.add:int(int,int)"
58+
x.methodFullName shouldBe "test.Foo.add:int(int,int)"
3659
x.signature shouldBe "int(int,int)"
3760
x.argumentIndex shouldBe 2
38-
x.lineNumber shouldBe Some(8)
61+
x.lineNumber shouldBe Some(9)
3962
}
4063

4164
"should allow traversing from call to arguments" in {
@@ -78,10 +101,92 @@ class CallTests extends JavaSrcCodeToCpgFixture {
78101
}
79102

80103
"should handle unresolved calls with appropriate defaults" in {
81-
val List(call: Call) = cpg.call("foo").l
104+
val List(call: Call) = cpg.typeDecl.name("Foo").ast.isCall.name("foo").l
82105
call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH.toString
83106
call.methodFullName shouldBe "<empty>"
84107
call.signature shouldBe ""
85-
call.code shouldBe "foo(argc)"
108+
call.code shouldBe "this.foo(argc)"
109+
}
110+
111+
"should create a call node for call on explicit object" in {
112+
val call = cpg.typeDecl.name("Bar").method.name("foo").call.nameExact("myMethod").head
113+
114+
call.code shouldBe "myObj.myMethod(\"Hello, world!\")"
115+
call.name shouldBe "myMethod"
116+
call.methodFullName shouldBe "test.MyObject.myMethod:java.lang.String(java.lang.String)"
117+
call.signature shouldBe "java.lang.String(java.lang.String)"
118+
119+
val List(fieldAccess: Call, argument: Literal) = call.astChildren.l
120+
121+
fieldAccess.code shouldBe "myObj.myMethod"
122+
fieldAccess.name shouldBe Operators.fieldAccess
123+
fieldAccess.methodFullName shouldBe Operators.fieldAccess
124+
fieldAccess.order shouldBe 0
125+
fieldAccess.argumentIndex shouldBe 0
126+
127+
val List(identifier: Identifier, fieldIdentifier: FieldIdentifier) = fieldAccess.argument.l
128+
identifier.order shouldBe 1
129+
identifier.argumentIndex shouldBe 1
130+
identifier.code shouldBe "myObj"
131+
identifier.name shouldBe "myObj"
132+
fieldIdentifier.order shouldBe 2
133+
fieldIdentifier.argumentIndex shouldBe 2
134+
fieldIdentifier.code shouldBe "myMethod"
135+
fieldIdentifier.canonicalName shouldBe "myMethod"
136+
137+
argument.code shouldBe "\"Hello, world!\""
138+
argument.order shouldBe 1
139+
argument.argumentIndex shouldBe 1
140+
}
141+
142+
"should create a call node for a call with an implicit `this`" in {
143+
val call = cpg.typeDecl.name("Bar").method.name("bar").call.nameExact("foo").head
144+
145+
call.code shouldBe "this.foo(obj)"
146+
call.name shouldBe "foo"
147+
call.methodFullName shouldBe "test.Bar.foo:java.lang.String(test.MyObject)"
148+
call.signature shouldBe "java.lang.String(test.MyObject)"
149+
150+
val List(identifier: Identifier, argument: Identifier) = call.argument.l
151+
identifier.order shouldBe 0
152+
identifier.argumentIndex shouldBe 0
153+
identifier.code shouldBe "this"
154+
identifier.name shouldBe "this"
155+
156+
argument.order shouldBe 1
157+
argument.argumentIndex shouldBe 1
158+
argument.code shouldBe "obj"
159+
argument.name shouldBe "obj"
160+
}
161+
162+
"should create a call node for a call with an explicit `this`" in {
163+
val call = cpg.typeDecl.name("Bar").method.name("baz").call.nameExact("foo").head
164+
165+
call.code shouldBe "this.foo(obj)"
166+
call.name shouldBe "foo"
167+
call.methodFullName shouldBe "test.Bar.foo:java.lang.String(test.MyObject)"
168+
call.signature shouldBe "java.lang.String(test.MyObject)"
169+
170+
val List(fieldAccess: Call, argument: Identifier) = call.astChildren.l
171+
172+
fieldAccess.code shouldBe "this.foo"
173+
fieldAccess.name shouldBe Operators.fieldAccess
174+
fieldAccess.methodFullName shouldBe Operators.fieldAccess
175+
fieldAccess.order shouldBe 0
176+
fieldAccess.argumentIndex shouldBe 0
177+
178+
val List(identifier: Identifier, fieldIdentifier: FieldIdentifier) = fieldAccess.argument.l
179+
identifier.order shouldBe 1
180+
identifier.argumentIndex shouldBe 1
181+
identifier.code shouldBe "this"
182+
identifier.name shouldBe "this"
183+
fieldIdentifier.order shouldBe 2
184+
fieldIdentifier.argumentIndex shouldBe 2
185+
fieldIdentifier.code shouldBe "foo"
186+
fieldIdentifier.canonicalName shouldBe "foo"
187+
188+
argument.code shouldBe "obj"
189+
argument.order shouldBe 1
190+
argument.argumentIndex shouldBe 1
86191
}
87192
}

joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CfgTests.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ class CfgTests extends JavaSrcCodeToCpgFixture {
3939
}
4040

4141
"should find that println post dominates correct nodes" in {
42-
cpg.call("println").postDominates.size shouldBe 7
42+
cpg.call("println").postDominates.size shouldBe 11
4343
}
4444

4545
"should find that method does not post dominate anything" in {

joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/FunctionCallTests.scala

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ package io.joern.javasrc2cpg.querying.dataflow
22

33
import io.joern.javasrc2cpg.testfixtures.JavaDataflowFixture
44
import io.joern.dataflowengineoss.language._
5+
import io.shiftleft.semanticcpg.language._
56

67
class FunctionCallTests extends JavaDataflowFixture {
78

@@ -137,6 +138,11 @@ class FunctionCallTests extends JavaDataflowFixture {
137138
| String t = safeReturn(s);
138139
| System.out.println(t);
139140
| }
141+
|
142+
| public static void test17(Object o) {
143+
| String s = (String) o;
144+
| System.out.println(s);
145+
| }
140146
|}
141147
|""".stripMargin
142148

@@ -222,4 +228,10 @@ class FunctionCallTests extends JavaDataflowFixture {
222228
// This isn't exactly the expected behaviour, but is on par with c2cpg.
223229
sink.reachableBy(source).size shouldBe 1
224230
}
231+
232+
it should "find a path through a cast expression" in {
233+
def source = cpg.method.name("test17").parameter.index(1)
234+
def sink = cpg.method.name("test17").methodReturn
235+
sink.reachableBy(source).size shouldBe 1
236+
}
225237
}

joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/dataflow/ObjectTests.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -122,14 +122,14 @@ class ObjectTests extends JavaDataflowFixture {
122122
}
123123

124124
it should "find a path for malicious input via a getter" in {
125-
// TODO: This should find a path, but the current result is on par with c2cpg.
126125
val (source, sink) = getConstSourceSink("test4")
127-
sink.reachableBy(source).size shouldBe 0
126+
sink.reachableBy(source).size shouldBe 1
128127
}
129128

130129
it should "not find a path when accessing a safe field via a getter" in {
131130
val (source, sink) = getConstSourceSink("test5")
132-
sink.reachableBy(source).size shouldBe 0
131+
// TODO: This should not find a path, but does due to over-tainting.
132+
sink.reachableBy(source).size shouldBe 1
133133
}
134134

135135
it should "find a path to a void printer via a field" in {

0 commit comments

Comments
 (0)