Monday, July 26, 2010

Just-In-Time Compiler for a Managed Platform, Part 3: Calling a Method

Today I'll try to extend the simple JIT compiler to the point where we can call a method from another method.

First, let's create a simple Java class:

// Sample class for this article: compile it with javac and feed the
// resulting bytecode (listed below) to the JIT.
public class Math
{
// Returns x + y. Storing the sum in the local r makes javac emit
// istore_2/iload_2 (see the listing), which the JIT's Compile handles.
public static int add(int x, int y)
{
int r;
r= x+y;
return r;
}

// Calls add with two constants; compiles to bipush, bipush, invokestatic,
// ireturn and should return 46.
public static int SimpleCall2()
{
return add(17, 29);
}
}

Here SimpleCall2 calls the add method, and our JIT compiler must supply a mechanism to handle that call. Compiling with the Java compiler and disassembling the resulting class file gives the following class and method bytecode:

public class Math extends java.lang.Object{
public Math();
Signature: ()V
Code:
0: aload_0
1: invokespecial #1;
4: return

public static int add(int, int);
Signature: (II)I
Code:
0: iload_0
1: iload_1
2: iadd
3: istore_2
4: iload_2
5: ireturn

public static int SimpleCall2();
Signature: ()I
Code:
0: bipush 17
2: bipush 29
4: invokestatic #2; //Method add:(II)I
7: ireturn
}


[Note: I do not describe Java Virtual Machine basics again here; for an introduction, see my article Home Made Java Virtual Machine.]

First, we need to extend our Context structure to hold some more values. New members must be added at the end; otherwise the field offsets hard-coded into the machine code we have generated so far would no longer be valid.

// Per-VM data reached from a Context through its pVMEnv member.
// Field order is relied on by generated code: EmitCallMethod's machine code
// reads ppHelperMethods at byte offset 8 ([edx+8]), which assumes 32-bit
// pointers. Append new members at the end only.
struct VMEnvironment
{
ObjectHeap *pObjectHeap;
ClassHeap *pClassHeap;
void **ppHelperMethods; // table of C++ helpers callable from generated code (main() sets it to HelperMethods)
};

// Runtime state passed to generated native code as its only argument
// (called pRE throughout). CallMethod passes the same Context on to the callee.
// Field order is relied on by generated code: EmitCallMethod reads pVMEnv at
// byte offset 0x10 ([ecx+10h]) on 32-bit builds. Append new members at the
// end only.
struct Context
{
Variable *stack;          // stack slots (main() allocates STACK_SIZE of them)
int stackTop;             // top index into stack; CallMethod moves it around each call
JavaClass *pClass;        // not assigned anywhere in this listing
Context *pCallerContext;  // not assigned anywhere in this listing
VMEnvironment *pVMEnv;    // heaps and helper-function table
};


We also want to keep track of the native code we generate, so we need another structure:

struct MethodLink
{
JavaClass *pClass;
method_info_ex *pMethod;
void *pNativeBlock;
};


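Now we need a helper function to keep track of the methods we work with. It uses a simple string-to-pointer map whose key is built from the class name, method name, and method descriptor, so a key looks like "Math::add(II)I". Note that every index stored in the invoke instruction's method reference belongs to the calling class's constant pool, so all three strings are read from the caller's pool.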

// Resolves the method referenced by the invoke instruction at bytecode offset
// `pc` of pMethod (a method of pClass) and returns its MethodLink, creating
// and caching the link on first use.
//
// The instruction's two-byte operand indexes the CALLER's constant pool, and
// the class and name-and-type indexes stored in that entry refer to the
// caller's pool as well. The cache key is therefore built entirely from
// pClass, e.g. "Math::add(II)I".
//
// Links are kept in a static map and never freed, because generated code
// embeds their addresses. Returns NULL (asserting in debug builds) if the
// callee class or method cannot be found.
MethodLink* GetMethod(JavaClass *pClass, method_info_ex *pMethod, u4 pc)
{
    static CMapStringToPtr methodsMap;
    // One class heap reused by every lookup instead of a new (leaked) heap per call.
    static ClassHeap *pClassHeap = new ClassHeap();

    // Operand of invoke*: index of the method reference in the caller's pool.
    u2 mi = getu2(&pMethod->pCode_attr->code[pc+1]);
    char *pConstPool = (char *)pClass->constant_pool[mi];

    u2 classIndex = getu2(&pConstPool[1]);
    u2 nameAndTypeIndex = getu2(&pConstPool[3]);

    // Name of the referenced class.
    pConstPool = (char *)pClass->constant_pool[classIndex];
    ASSERT(pConstPool[0] == CONSTANT_Class);

    u2 ni = getu2(&pConstPool[1]);
    CString strClassName;
    pClass->GetStringFromConstPool(ni, strClassName);

    // Method name and descriptor. nameAndTypeIndex came from the caller's
    // pool, so it must be dereferenced there. Looking it up in the callee's
    // pool only happens to work when caller and callee are the same class.
    pConstPool = (char *)pClass->constant_pool[nameAndTypeIndex];
    ASSERT(pConstPool[0] == CONSTANT_NameAndType);

    u2 name_index = getu2(&pConstPool[1]);
    u2 descriptor_index = getu2(&pConstPool[3]);

    CString strMethodName, strMethodDesc;
    pClass->GetStringFromConstPool(name_index, strMethodName);
    pClass->GetStringFromConstPool(descriptor_index, strMethodDesc);

    CString sign(strClassName + "::" + strMethodName + strMethodDesc);
    MethodLink *pLink = NULL;
    if(methodsMap.Lookup(sign, (void *&)pLink))
    {
        return pLink;   // already resolved; no need to load the class again
    }

    JavaClass *pClassCallee = pClassHeap->GetClass(strClassName);
    ASSERT(pClassCallee != NULL);
    if(NULL == pClassCallee)
    {
        return NULL;
    }

    // main() indexes `methods` on whatever class is left in pVirtualClass
    // after GetMethodIndex, so do the same here. NOTE(review): presumably
    // GetMethodIndex can redirect it to the class that actually declares the
    // method; confirm against its definition.
    JavaClass *pVirtualClass = pClassCallee;
    int nIndex = pClassCallee->GetMethodIndex(strMethodName, strMethodDesc, pVirtualClass);
    ASSERT(nIndex >= 0);
    if(nIndex < 0)
    {
        return NULL;
    }

    pLink = new MethodLink();
    pLink->pClass = pVirtualClass;
    pLink->pMethod = &pVirtualClass->methods[nIndex];
    pLink->pNativeBlock = NULL;
    methodsMap.SetAt(sign, pLink);

    return pLink;
}


To keep things simple, we do not yet generate the stack-preparation code for a call in machine code; we'll do that after we finish the other kinds of code generation. Instead, the native code calls back into a C++ function, which in turn calls into the generated code:

void CallMethod(MethodLink *pMethodLink, Context *pRE)
{
LOG(_T("CallMethod\n"));

int codeBlockSize = pMethodLink->pMethod->pCode_attr->code_length*2; //todo guess better

int (*NativeBlock)(Context *)=(int (*)(Context *)) VirtualAlloc(NULL, codeBlockSize, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
u1* codes = (u1*) NativeBlock;

int ip =0;

JavaClass *pClass = pMethodLink->pClass;

if(NULL == pMethodLink->pNativeBlock)
{
Compile(pMethodLink->pClass, pMethodLink->pMethod, codes, ip);
pMethodLink->pNativeBlock = codes;
}

CString strName, strDesc;
pMethodLink->pClass->GetStringFromConstPool(pMethodLink->pMethod->name_index, strName);
pMethodLink->pClass->GetStringFromConstPool(pMethodLink->pMethod->descriptor_index, strDesc);

int params=GetMethodParametersStackCount(strDesc)+1;

//invokestatic: we are only dealing with static methods so far

int nDiscardStack =params;
if(pMethodLink->pMethod->access_flags & ACC_NATIVE)
{
}
else
{
nDiscardStack+= pMethodLink->pMethod->pCode_attr->max_locals;
}

pRE->stackTop+=(nDiscardStack-1);
LOG(_T("Invoking method %s%s, \n"), strName, strDesc);

(*NativeBlock)(pRE);

//if returns then get on stack
if(strDesc.Find(_T(")V")) < 0)
{
nDiscardStack--;
if(strDesc.Find(_T(")J")) < 0)
{
}
else
{
nDiscardStack--;
}
}

pRE->stackTop-=nDiscardStack;
LOG(_T("~CallMethod\n"));
}


OK, that's the callback we need for now. Next we generate the actual machine code that uses the MethodLink* value to call back into the CallMethod function. To do this, we keep a table of helper function pointers and store it in the context's VM environment:

// Index of CallMethod in the helper table. EmitCallMethod's generated code
// reads entry [eax] (index 0) directly, so it must stay in sync with this value.
#define CALL_METHOD_HELPER_INDEX 0

// C++ helpers reachable from generated code through
// Context::pVMEnv->ppHelperMethods. Standard C++ has no implicit conversion
// from a function pointer to void*, so each entry is cast explicitly.
void* HelperMethods[] = {
    reinterpret_cast<void *>(CallMethod),
};

Now let's define the InvokeStatic helper, which emits the native call sequence for an invokestatic instruction:

void InvokeStatic(JavaClass *pClass, method_info_ex *pMethod, u4 pc, u1* codes, int &ip)
{
MethodLink* pLink = GetMethod(pClass, pMethod, pc);
EmitCallMethod(codes, ip, pLink);
}

// Appends 32-bit x86 code at code[ip] equivalent to
//   ((void (*)(MethodLink *, Context *))
//       pRE->pVMEnv->ppHelperMethods[CALL_METHOD_HELPER_INDEX])(pLinkAddress, pRE);
// and advances ip past it. pRE is read from [ebp+8], the generated
// function's argument.
//
// Hard-coded 32-bit assumptions: pLinkAddress fits the 4-byte push
// immediate, Context::pVMEnv is at offset 0x10, VMEnvironment::ppHelperMethods
// is at offset 8, and the helper index is 0 ([eax]). Not valid for 64-bit builds.
void EmitCallMethod(u1* code, int &ip, void* pLinkAddress)
{
    // Byte offset of the imm32 operand of "push imm32" (opcode 0x68) in c[].
    const int kLinkImmOffset = 5;

    u1 c[] = {
        0x8B, 0x45, 0x08,             // mov eax,dword ptr [pRE]
        0x50,                         // push eax                    ; 2nd arg: pRE
        0x68, 0x00, 0x00, 0x00, 0x00, // push pLinkAddress           ; 1st arg, patched below
        0x8B, 0x4D, 0x08,             // mov ecx,dword ptr [pRE]
        0x8B, 0x51, 0x10,             // mov edx,dword ptr [ecx+10h] ; pRE->pVMEnv
        0x8B, 0x42, 0x08,             // mov eax,dword ptr [edx+8]   ; ->ppHelperMethods
        0x8B, 0x08,                   // mov ecx,dword ptr [eax]     ; entry 0 (CallMethod)
        0xFF, 0xD1,                   // call ecx
        0x83, 0xC4, 0x08,             // add esp,8                   ; caller pops both args
    };

    memcpy(c + kLinkImmOffset, &pLinkAddress, 4);
    memcpy(&code[ip], c, sizeof(c));
    ip += sizeof(c);
}


To compile methods, we define a function that generates machine code for Java bytecode. It does not handle branch instructions yet. Handling branches probably requires two passes, because the exact target addresses are not known during the first pass. For now, here is a large while loop that handles the basic instructions:

// Translates the bytecode of pMethod (a method of pClass) into native code
// appended at codes[ip], advancing ip. Only the opcodes in the switch below
// are supported; branches are not handled yet.
//
// Returns 0 on success. Returns 1 if the method is ACC_NATIVE (nothing is
// emitted). Also returns 1 if an unsupported opcode stopped translation; in
// that case the Return0/Epilog tail is still emitted so the block can return.
u4 Compile(JavaClass *pClass, method_info_ex *pMethod, u1 *codes, int &ip)
{
    if(pMethod->access_flags & ACC_NATIVE)
    {
        return 1;
    }

    Prolog(codes, ip);

    u4 pc = 0;
    u1 *bc = pMethod->pCode_attr->code;

    i4 error = 0;

    // Method name, used only for diagnostics.
    CString strMethod;
    pClass->GetStringFromConstPool(pMethod->name_index, strMethod);

    while(pMethod->pCode_attr->code_length > pc)
    {
        LOG(_T("Opcode = %s\n"), OpcodeDesc[(u1)bc[pc]]);

        switch(bc[pc])
        {
        case nop:
            pc++;
            break;

        case bipush: // 16 /*(0x10)*/
            BiPush(codes, ip, (u1)bc[pc+1]);
            pc += 2;
            break;

        case iload_0: //26 Load int from local variable 0
            ILoad_0(codes, ip);
            pc++;
            break;

        case iload_1: //27 Load int from local variable 1
            ILoad_1(codes, ip);
            pc++;
            break;

        case iload_2: //28 Load int from local variable 2
            ILoad_2(codes, ip);
            pc++;
            break;

        case iload_3: //29 Load int from local variable 3
            ILoad_3(codes, ip);
            pc++;
            break;

        case istore_2: // 61 /*(0x3d) */
            IStore_2(codes, ip);
            pc++;
            break;

        case iadd: //96
            IAdd(codes, ip);
            pc++;
            break;

        case invokestatic: // 184; two-byte constant-pool operand follows
            InvokeStatic(pClass, pMethod, pc, codes, ip);
            pc += 3;
            break;

        case ireturn: //172 (0xac)
            IReturn(codes, ip);
            pc++;
            break;

        default:
            // Report instead of silently dropping the rest of the method.
            LOG(_T("Unsupported opcode %s in method %s\n"), OpcodeDesc[(u1)bc[pc]], strMethod);
            error = 1;
            break;
        }

        if(error) break;
    }

    Return0(codes, ip);
    Epilog(codes, ip);

    return error;
}


OK, we are now ready to test our code:

// Test driver: builds a runtime context, loads class Math, and runs
// Math.SimpleCall2() through the CallMethod helper. The int result is then
// read from stack slot 0.
int main()
{
    Context *pRE = new Context();
    pRE->stack = new Variable[STACK_SIZE];
    pRE->stackTop = 0;
    memset(pRE->stack, 0, sizeof(Variable)*STACK_SIZE);

    pRE->pVMEnv = new VMEnvironment();
    pRE->pVMEnv->pClassHeap = new ClassHeap();
    pRE->pVMEnv->pObjectHeap = new ObjectHeap();

    // Generated code reaches CallMethod through this table.
    pRE->pVMEnv->ppHelperMethods = HelperMethods;

    ClassHeap* pClsHeap = pRE->pVMEnv->pClassHeap;

    JavaClass jc;
    pClsHeap->LoadClass("Math", &jc);
    JavaClass *pVirtualClass = &jc, *pClass1 = &jc;

    int mindex = pClass1->GetMethodIndex(_T("SimpleCall2"), _T("()I"), pVirtualClass);
    if(mindex < 0)
    {
        LOG(_T("Method SimpleCall2()I not found\n"));
        return 1;
    }

    method_info_ex *pMethod = &pVirtualClass->methods[mindex];

    MethodLink *pMethodLink = new MethodLink();
    pMethodLink->pClass = pVirtualClass;
    pMethodLink->pMethod = pMethod;
    pMethodLink->pNativeBlock = NULL;   // not compiled yet; CallMethod compiles on first use

    ((void (*)(MethodLink *pMethodLink, Context *pRE))pRE->pVMEnv->ppHelperMethods[CALL_METHOD_HELPER_INDEX])(pMethodLink, pRE);
    LOG(_T("Return Value = %d\n"), pRE->stack[0].intValue);

    return 0;
}

Do you see the value 46 (17 + 29) on the stack as the return value? Cool!