/* rshd.c - rshd for VxWorks
 *
 * Copyright (C) 1993-2001 Matthew R. Wette -- all rights reserved.
 *
 * $Id: rshd.c,v 1.1 2001/10/08 10:16:45 borkhuis Exp $
 *
 * You may use this program if this copyright notice is kept in tact and
 * you don't try to sell it.
 *
 * To use:
 *	compile to rshd
 *	ld < rshd
 *	rshdInit
 *
 * This program is subject to a memory leak in the VxWorks shell.  Every
 * rsh command will eat a little memory.  Also, if you're using VxWorks 5.3
 * or later, you'll need the SHELL compiled in.
 */
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <vxWorks.h>
#include <ioLib.h>
#include <logLib.h>
#include <remLib.h>
#include <shellLib.h>
#include <sockLib.h>
#include <stdioLib.h>
#include <sysLib.h>
#include <taskLib.h>
#include <in.h>
#include <iv.h>
#include <socket.h>
#include <netinet/tcp.h>

static char const rcsid[] = "$Id: rshd.c,v 1.1 2001/10/08 10:16:45 borkhuis Exp $";

/* TCP port the daemon listens on (514 is the well-known rsh "shell" port).
   Override at build time with -DRSHD_PORT=<port>. */
#ifndef RSHD_PORT
#define RSHD_PORT	514
#endif
/* taskSpawn() priority of the tRshd task; override with -DRSHD_PRIO=<prio>. */
#ifndef RSHD_PRIO
#define RSHD_PRIO	5
#endif

/* Command classes.  NOTE(review): not referenced in the visible part of this
   file; kept for compatibility with any external users. */
#define SHELL_CMD	1		/* is a shell command */
#define RESET_CMD	2		/* is a reset command */
#define FEVAL_CMD	3		/* is a function eval command */

int rshdInit(void);
int rshdTask(void);

static int rshd(int sock, struct sockaddr_in *addr);
static int getstr(int sock, char *bufr, int cnt, char *errmsg);

static int rshdTid = 0;			/* task ID of tRshd, set by rshdInit() */
static int rshdStackSize = 0x8000;	/* stack size passed to taskSpawn() */
static char rcmd_errmsg[80];		/* scratch text for errors sent to clients */


int
rshdInit(void)
{
  rshdTid = taskSpawn("tRshd", RSHD_PRIO, VX_STDIO, rshdStackSize, rshdTask,
		      0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
  if (rshdTid == 0) return ERROR;
  return OK;
}


int
rshdTask(void)
{
  int status, serv_sock, clnt_sock, clen;
  struct sockaddr_in servaddr, clntaddr;

  memset(&servaddr, '\0', sizeof(servaddr));
  memset(&clntaddr, '\0', sizeof(clntaddr));

  /* Open the socket. */
  serv_sock = socket(AF_INET, SOCK_STREAM, 0);
  if (serv_sock == ERROR) return ERROR;

  servaddr.sin_family = AF_INET;
  servaddr.sin_port = htons(RSHD_PORT);

  status = bind(serv_sock, (struct sockaddr *)&servaddr, sizeof(servaddr));
  if (status == ERROR) {
    close(serv_sock);
    logMsg("rshd: bind failed\n", 0, 0, 0, 0, 0, 0);
    return ERROR;
  }
  status = listen(serv_sock, 2);
  if (status == ERROR) {
    logMsg("rshd: listen failed\n", 0, 0, 0, 0, 0, 0);
    close(serv_sock);
    return ERROR;
  }
  while (1) {
    clen = sizeof(clntaddr);
    clnt_sock = accept(serv_sock, (struct sockaddr*)&clntaddr, &clen);
    if (clnt_sock == ERROR) {
      logMsg("rshd: accept failed\n", 0, 0, 0, 0, 0, 0);
      continue;
    }
    status = rshd(clnt_sock, &clntaddr);
    close(clnt_sock);
  }
  return OK;
}


static int
rshd(int sock, struct sockaddr_in *addr)
{
  int n, status, was_stdout, was_stderr;
  int secdport, secdsock, loclport;
  char c, *cmd, clntuser[16], servuser[16], bufr[128];

  /* Get secondary port request. */
  /*alarm(60);*/
  secdport = 0;
  while ((n = read(sock, &c, 1)) == 1) {
    if (c == '\0') break;
    secdport = 10*secdport + c - '0';
  }
  if (secdport != 0) {
    loclport = IPPORT_RESERVED - 1;
    secdsock = rresvport(&loclport);
    addr->sin_port = htons((unsigned short)secdport);
    status = connect(secdsock, (struct sockaddr*)addr, sizeof(*addr));
    if (status == ERROR) {
      logMsg("rshd: connect() on second port %d failed\n", secdport, 0, 0, 0,
	  0, 0);
      close(sock);
      return status;
    }
  }

  /* get the command strings */
  getstr(sock, clntuser, sizeof(clntuser), "clntuser");
  getstr(sock, servuser, sizeof(servuser), "servuser");
  getstr(sock, bufr, sizeof(bufr), "command");
# ifdef DEBUG
  logMsg("rshd: command=[%s]\n", bufr, 0, 0, 0, 0, 0);
# endif

  cmd = bufr;
  while (*cmd == ' ' || *cmd == '\t') cmd++; /* strip leading whitespace */
  if (*cmd++ != '-') {
    /* -- lock shell -- */
    status = shellLock(TRUE);
    if (status == FALSE) {
      logMsg("rshd: failed to lock shell\n", 0, 0, 0, 0, 0, 0);
      write(sock, "!", 1);
      remCurIdGet(servuser, 0);
      sprintf(rcmd_errmsg, "rshd: shell locked by user \"%s\"\n", servuser);
      if (strlen(rcmd_errmsg) > 80) rcmd_errmsg[79] = '\n';
      write(sock, rcmd_errmsg, strlen(rcmd_errmsg));
      goto exit1;
    }
    status = remCurIdSet(clntuser, 0);
    /* -- set I/O -- */
    was_stdout = ioGlobalStdGet(1);
    ioGlobalStdSet(1, sock);
    if (secdport) {
      was_stderr = ioGlobalStdGet(2);
      ioGlobalStdSet(2, secdsock);
    }
    /* -- let the client know we're alive -- */
    if (write(sock, "", 1) != 1) {
      logMsg("rshd: null write to client failed\n", 0, 0, 0, 0, 0, 0);
      goto exit3;
    }
    /* -- execute -- */
    status = execute(bufr);
    if (status == ERROR) {
      logMsg("rshd: execute() failed\n", 0, 0, 0, 0, 0, 0);
      goto exit3;
    }
    break;
  } else {
    /* insert hooks here */
    logMsg("rshd: bad command\n", 0, 0, 0, 0, 0, 0);
    goto exit2;
  }

 exit3:
  /* Reestablish I/O */
  ioGlobalStdSet(STD_OUT, was_stdout);
  if (secdport) ioGlobalStdSet(STD_ERR, was_stderr);

 exit2:
  /* unlock shell */
  status = shellLock(FALSE);

 exit1:
  /* close sockets */
  close(sock);
  if (secdport) close(secdsock);

  if (qCmdPending) qCmdExec();

  return status;
}


static int
getstr(int sock, char *bufr, int cnt, char* errmsg)
{
  int i, n;
  char c;

  bufr[i=0] = '\0';
  while ((n = read(sock, &c, 1)) == 1) {
    bufr[i++] = c;
    if (i == cnt) return ERROR;
    if (c == '\0') return OK;
  }
  if (n == -1) return ERROR;
  return OK;
}


/*
SYNOPSIS
    int rshdInit(void)                  - starts rshdTask
    int rshdTask(void)                  - server task
    static int rshd(int, struct sockaddr_in *) - service one request (internal)

  on Unix host:
    rsh <host> <command>

DESCRIPTION
This is an implementation of the remote shell daemon for VxWorks hosts.
It allows users on Unix hosts to run VxWorks shell commands with rsh(1).


ERRORS
If an error occurs, rsh prints a single-line error message.

BUGS
Surely some exist.

SEE ALSO: rsh(1), rcmd(3), rshd(8), in.rshd(8)

AUTHOR

Matt Wette <mwette@alumni.caltech.edu>, Jet Propulsion Laboratory
*/

/* --- last line of rshd.c --- */

